diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index fbc89e98847bf..d29af00affb98 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -864,6 +864,14 @@ captureJVMException <- function(e, method) {
     # Extract the first message of JVM exception.
     first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
     stop(paste0(rmsg, "no such table - ", first), call. = FALSE)
+  } else if (any(grep("org.apache.spark.sql.catalyst.parser.ParseException: ", stacktrace))) {
+    msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.parser.ParseException: ",
+                    fixed = TRUE)[[1]]
+    # Extract "Error in ..." message.
+    rmsg <- msg[1]
+    # Extract the first message of JVM exception.
+    first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
+    stop(paste0(rmsg, "parse error - ", first), call. = FALSE)
   } else {
     stop(stacktrace, call. = FALSE)
   }
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 2cef7191d4f2a..1a3d6df437d7e 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -150,7 +150,12 @@ test_that("structField type strings", {
                          binary = "BinaryType",
                          boolean = "BooleanType",
                          timestamp = "TimestampType",
-                         date = "DateType")
+                         date = "DateType",
+                         tinyint = "ByteType",
+                         smallint = "ShortType",
+                         int = "IntegerType",
+                         bigint = "LongType",
+                         decimal = "DecimalType(10,0)")
 
   complexTypes <- list("map<string,integer>" = "MapType(StringType,IntegerType,true)",
                        "array<string>" = "ArrayType(StringType,true)",
@@ -174,7 +179,11 @@ test_that("structField type strings", {
                           numeric = "numeric",
                           character = "character",
                           raw = "raw",
-                          logical = "logical")
+                          logical = "logical",
+                          short = "short",
+                          varchar = "varchar",
+                          long = "long",
+                          char = "char")
 
   complexErrors <- list("map<string, integer>" = " integer",
                         "array<String>" = "String",
diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R
index 6d006eccf665e..1ca383da26ec2 100644
--- a/R/pkg/inst/tests/testthat/test_utils.R
+++ b/R/pkg/inst/tests/testthat/test_utils.R
@@ -167,13 +167,13 @@ test_that("convertToJSaveMode", {
 })
 
 test_that("captureJVMException", {
-  method <- "getSQLDataType"
+  method <- "createStructField"
   expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method,
-                                    "unknown"),
+                                    "col", "unknown", TRUE),
                         error = function(e) {
                           captureJVMException(e, method)
                         }),
-               "Error in getSQLDataType : illegal argument - Invalid type unknown")
+               "parse error - .*DataType unknown.*not supported.")
 })
 
 test_that("hashCode", {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index a26d00411fbaa..d94e528a3ad47 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -31,6 +31,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.execution.command.ShowTablesCommand
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.types._
@@ -92,48 +93,8 @@ private[sql] object SQLUtils extends Logging {
     def r: Regex = new Regex(sc.parts.mkString, sc.parts.tail.map(_ => "x"): _*)
   }
 
-  def getSQLDataType(dataType: String): DataType = {
-    dataType match {
-      case "byte" => org.apache.spark.sql.types.ByteType
-      case "integer" => org.apache.spark.sql.types.IntegerType
-      case "float" => org.apache.spark.sql.types.FloatType
-      case "double" => org.apache.spark.sql.types.DoubleType
-      case "numeric" => org.apache.spark.sql.types.DoubleType
-      case "character" => org.apache.spark.sql.types.StringType
-      case "string" => org.apache.spark.sql.types.StringType
-      case "binary" => org.apache.spark.sql.types.BinaryType
-      case "raw" => org.apache.spark.sql.types.BinaryType
-      case "logical" => org.apache.spark.sql.types.BooleanType
-      case "boolean" => org.apache.spark.sql.types.BooleanType
-      case "timestamp" => org.apache.spark.sql.types.TimestampType
-      case "date" => org.apache.spark.sql.types.DateType
-      case r"\Aarray<(.+)${elemType}>\Z" =>
-        org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType))
-      case r"\Amap<(.+)${keyType},(.+)${valueType}>\Z" =>
-        if (keyType != "string" && keyType != "character") {
-          throw new IllegalArgumentException("Key type of a map must be string or character")
-        }
-        org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType))
-      case r"\Astruct<(.+)${fieldsStr}>\Z" =>
-        if (fieldsStr(fieldsStr.length - 1) == ',') {
-          throw new IllegalArgumentException(s"Invalid type $dataType")
-        }
-        val fields = fieldsStr.split(",")
-        val structFields = fields.map { field =>
-          field match {
-            case r"\A(.+)${fieldName}:(.+)${fieldType}\Z" =>
-              createStructField(fieldName, fieldType, true)
-
-            case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
-          }
-        }
-        createStructType(structFields)
-      case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
-    }
-  }
-
   def createStructField(name: String, dataType: String, nullable: Boolean): StructField = {
-    val dtObj = getSQLDataType(dataType)
+    val dtObj = CatalystSqlParser.parseDataType(dataType)
     StructField(name, dtObj, nullable)
   }
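For context, the net effect of the SQLUtils.scala hunk is that SparkR type strings are now resolved by the same Catalyst parser used for SQL DDL, rather than by the hand-rolled `getSQLDataType` matcher; the new positive cases in test_sparkSQL.R (`tinyint`, `smallint`, `int`, `bigint`, `decimal`) come along for free, and strings the parser does not know now surface as the "parse error - ..." message added to `captureJVMException`. Below is a minimal sketch of that new parsing path, assuming only that spark-catalyst is on the classpath; the object name `ParseDataTypeSketch` is invented for illustration, and the expected outputs in the comments mirror the strings asserted in the tests above.

```scala
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}

object ParseDataTypeSketch {
  def main(args: Array[String]): Unit = {
    // DDL-style type strings all resolve through the one shared parser.
    println(CatalystSqlParser.parseDataType("int"))             // IntegerType
    println(CatalystSqlParser.parseDataType("decimal"))         // DecimalType(10,0)
    println(CatalystSqlParser.parseDataType("array<string>"))   // ArrayType(StringType,true)
    println(CatalystSqlParser.parseDataType("map<string,int>")) // MapType(StringType,IntegerType,true)

    // Unknown type strings raise ParseException; the new branch in
    // captureJVMException rewrites that into "parse error - ..." for R users.
    try {
      CatalystSqlParser.parseDataType("unknown")
    } catch {
      case e: ParseException =>
        println(e.getMessage) // mentions "DataType unknown ... not supported"
    }
  }
}
```

One design note: delegating to `CatalystSqlParser.parseDataType` de-duplicates the type-string grammar between SQL and R, so any type the SQL side learns to parse becomes available to `structField` in R without further changes, at the cost of dropping the R-only aliases the old matcher accepted.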