Skip to content

Commit

Permalink
De-duplicate parse logic for DDL-like type strings in R
Browse files Browse the repository at this point in the history
  • Loading branch information
HyukjinKwon committed Apr 27, 2017
1 parent ba76662 commit 8874de1
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 44 deletions.
9 changes: 9 additions & 0 deletions R/pkg/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,15 @@ captureJVMException <- function(e, method) {
# Extract the first message of JVM exception.
first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
stop(paste0(rmsg, "no such table - ", first), call. = FALSE)
} else
if (any(grep("org.apache.spark.sql.catalyst.parser.ParseException: ", stacktrace))) {
msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.parser.ParseException: ",
fixed = TRUE)[[1]]
# Extract "Error in ..." message.
rmsg <- msg[1]
# Extract the first message of JVM exception.
first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
stop(paste0(rmsg, "parse error - ", first), call. = FALSE)
} else {
stop(stacktrace, call. = FALSE)
}
Expand Down
6 changes: 3 additions & 3 deletions R/pkg/inst/tests/testthat/test_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,13 @@ test_that("convertToJSaveMode", {
})

test_that("captureJVMException", {
method <- "getSQLDataType"
method <- "createStructField"
expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method,
"unknown"),
"col", "unknown", TRUE),
error = function(e) {
captureJVMException(e, method)
}),
"Error in getSQLDataType : illegal argument - Invalid type unknown")
"parse error - .*DataType unknown.*not supported.")
})

test_that("hashCode", {
Expand Down
43 changes: 2 additions & 41 deletions sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.execution.command.ShowTablesCommand
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -92,48 +93,8 @@ private[sql] object SQLUtils extends Logging {
def r: Regex = new Regex(sc.parts.mkString, sc.parts.tail.map(_ => "x"): _*)
}

/**
 * Converts a type-name string into the corresponding Spark SQL `DataType`.
 *
 * Recognizes a fixed list of primitive names and three composite forms:
 * `array<elem>`, `map<key,value>` and `struct<name:type,...>`. Composite forms
 * recurse on their inner type strings (struct fields go through `createStructField`).
 *
 * @param dataType the type string to convert
 * @return the matching `DataType`
 * @throws IllegalArgumentException if the string matches no recognized form, if a map key
 *         type is neither "string" nor "character", or if a struct field list ends with a
 *         comma or contains an entry that is not of the form `name:type`
 */
def getSQLDataType(dataType: String): DataType = {
dataType match {
// Primitive names. Several names are aliases for the same type
// (e.g. "double"/"numeric", "character"/"string", "binary"/"raw", "logical"/"boolean").
case "byte" => org.apache.spark.sql.types.ByteType
case "integer" => org.apache.spark.sql.types.IntegerType
case "float" => org.apache.spark.sql.types.FloatType
case "double" => org.apache.spark.sql.types.DoubleType
case "numeric" => org.apache.spark.sql.types.DoubleType
case "character" => org.apache.spark.sql.types.StringType
case "string" => org.apache.spark.sql.types.StringType
case "binary" => org.apache.spark.sql.types.BinaryType
case "raw" => org.apache.spark.sql.types.BinaryType
case "logical" => org.apache.spark.sql.types.BooleanType
case "boolean" => org.apache.spark.sql.types.BooleanType
case "timestamp" => org.apache.spark.sql.types.TimestampType
case "date" => org.apache.spark.sql.types.DateType
// Composite forms use the `r` regex string interpolator: each `(.+)${name}` binds the
// captured group to `name`. `\A` / `\Z` anchor the pattern to the whole input.
case r"\Aarray<(.+)${elemType}>\Z" =>
org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType))
// NOTE(review): both groups are greedy `(.+)`, so when the text contains several commas
// the key group extends to the last one; the key check below then rejects such input.
case r"\Amap<(.+)${keyType},(.+)${valueType}>\Z" =>
// Only string-typed keys are accepted for maps.
if (keyType != "string" && keyType != "character") {
throw new IllegalArgumentException("Key type of a map must be string or character")
}
org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType))
case r"\Astruct<(.+)${fieldsStr}>\Z" =>
// Reject a trailing comma explicitly: String.split drops trailing empty strings, so
// "a:integer," would otherwise be accepted as a single field.
if (fieldsStr(fieldsStr.length - 1) == ',') {
throw new IllegalArgumentException(s"Invalid type $dataType")
}
// NOTE(review): plain split on "," with no bracket awareness — a field whose type itself
// contains a comma (e.g. a map-typed field) is cut in the middle of its type.
val fields = fieldsStr.split(",")
val structFields = fields.map { field =>
field match {
// Each entry must look like `name:type`; fields are always created as nullable.
case r"\A(.+)${fieldName}:(.+)${fieldType}\Z" =>
createStructField(fieldName, fieldType, true)

case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
}
}
createStructType(structFields)
// Anything else is not a recognized type string.
case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
}
}

def createStructField(name: String, dataType: String, nullable: Boolean): StructField = {
val dtObj = getSQLDataType(dataType)
val dtObj = CatalystSqlParser.parseDataType(dataType)
StructField(name, dtObj, nullable)
}

Expand Down

0 comments on commit 8874de1

Please sign in to comment.