From 70f1bcd7bcd42b30eabcf06a9639363f1ca4b449 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 29 Apr 2017 11:02:17 -0700 Subject: [PATCH] [SPARK-20493][R] De-duplicate parse logics for DDL-like type strings in R MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? It seems we are using `SQLUtils.getSQLDataType` for type string in structField. It looks we can replace this with `CatalystSqlParser.parseDataType`. They look similar DDL-like type definitions as below: ```scala scala> Seq(Tuple1(Tuple1("a"))).toDF.show() ``` ``` +---+ | _1| +---+ |[a]| +---+ ``` ```scala scala> Seq(Tuple1(Tuple1("a"))).toDF.select($"_1".cast("struct<_1:string>")).show() ``` ``` +---+ | _1| +---+ |[a]| +---+ ``` Such type strings looks identical when R’s one as below: ```R > write.df(sql("SELECT named_struct('_1', 'a') as struct"), "/tmp/aa", "parquet") > collect(read.df("/tmp/aa", "parquet", structType(structField("struct", "struct<_1:string>")))) struct 1 a ``` R’s one is stricter because we are checking the types via regular expressions in R side ahead. Actual logics there look a bit different but as we check it ahead in R side, it looks replacing it would not introduce (I think) no behaviour changes. To make this sure, the tests dedicated for it were added in SPARK-20105. (It looks `structField` is the only place that calls this method). ## How was this patch tested? Existing tests - https://github.com/apache/spark/blob/master/R/pkg/inst/tests/testthat/test_sparkSQL.R#L143-L194 should cover this. Author: hyukjinkwon Closes #17785 from HyukjinKwon/SPARK-20493. --- R/pkg/R/utils.R | 8 ++++ R/pkg/inst/tests/testthat/test_sparkSQL.R | 13 +++++- R/pkg/inst/tests/testthat/test_utils.R | 6 +-- .../org/apache/spark/sql/api/r/SQLUtils.scala | 43 +------------------ 4 files changed, 24 insertions(+), 46 deletions(-) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index fbc89e98847bf..d29af00affb98 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -864,6 +864,14 @@ captureJVMException <- function(e, method) { # Extract the first message of JVM exception. first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] stop(paste0(rmsg, "no such table - ", first), call. = FALSE) + } else if (any(grep("org.apache.spark.sql.catalyst.parser.ParseException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.parser.ParseException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "parse error - ", first), call. = FALSE) } else { stop(stacktrace, call. = FALSE) } diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 2cef7191d4f2a..1a3d6df437d7e 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -150,7 +150,12 @@ test_that("structField type strings", { binary = "BinaryType", boolean = "BooleanType", timestamp = "TimestampType", - date = "DateType") + date = "DateType", + tinyint = "ByteType", + smallint = "ShortType", + int = "IntegerType", + bigint = "LongType", + decimal = "DecimalType(10,0)") complexTypes <- list("map" = "MapType(StringType,IntegerType,true)", "array" = "ArrayType(StringType,true)", @@ -174,7 +179,11 @@ test_that("structField type strings", { numeric = "numeric", character = "character", raw = "raw", - logical = "logical") + logical = "logical", + short = "short", + varchar = "varchar", + long = "long", + char = "char") complexErrors <- list("map" = " integer", "array" = "String", diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 6d006eccf665e..1ca383da26ec2 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -167,13 +167,13 @@ test_that("convertToJSaveMode", { }) test_that("captureJVMException", { - method <- "getSQLDataType" + method <- "createStructField" expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, - "unknown"), + "col", "unknown", TRUE), error = function(e) { captureJVMException(e, method) }), - "Error in getSQLDataType : illegal argument - Invalid type unknown") + "parse error - .*DataType unknown.*not supported.") }) test_that("hashCode", { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index a26d00411fbaa..d94e528a3ad47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -31,6 +31,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.command.ShowTablesCommand import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types._ @@ -92,48 +93,8 @@ private[sql] object SQLUtils extends Logging { def r: Regex = new Regex(sc.parts.mkString, sc.parts.tail.map(_ => "x"): _*) } - def getSQLDataType(dataType: String): DataType = { - dataType match { - case "byte" => org.apache.spark.sql.types.ByteType - case "integer" => org.apache.spark.sql.types.IntegerType - case "float" => org.apache.spark.sql.types.FloatType - case "double" => org.apache.spark.sql.types.DoubleType - case "numeric" => org.apache.spark.sql.types.DoubleType - case "character" => org.apache.spark.sql.types.StringType - case "string" => org.apache.spark.sql.types.StringType - case "binary" => org.apache.spark.sql.types.BinaryType - case "raw" => org.apache.spark.sql.types.BinaryType - case "logical" => org.apache.spark.sql.types.BooleanType - case "boolean" => org.apache.spark.sql.types.BooleanType - case "timestamp" => org.apache.spark.sql.types.TimestampType - case "date" => org.apache.spark.sql.types.DateType - case r"\Aarray<(.+)${elemType}>\Z" => - org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType)) - case r"\Amap<(.+)${keyType},(.+)${valueType}>\Z" => - if (keyType != "string" && keyType != "character") { - throw new IllegalArgumentException("Key type of a map must be string or character") - } - org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType)) - case r"\Astruct<(.+)${fieldsStr}>\Z" => - if (fieldsStr(fieldsStr.length - 1) == ',') { - throw new IllegalArgumentException(s"Invalid type $dataType") - } - val fields = fieldsStr.split(",") - val structFields = fields.map { field => - field match { - case r"\A(.+)${fieldName}:(.+)${fieldType}\Z" => - createStructField(fieldName, fieldType, true) - - case _ => throw new IllegalArgumentException(s"Invalid type $dataType") - } - } - createStructType(structFields) - case _ => throw new IllegalArgumentException(s"Invalid type $dataType") - } - } - def createStructField(name: String, dataType: String, nullable: Boolean): StructField = { - val dtObj = getSQLDataType(dataType) + val dtObj = CatalystSqlParser.parseDataType(dataType) StructField(name, dtObj, nullable) }