[SPARK-20493][R] De-duplicate parse logics for DDL-like type strings in R

## What changes were proposed in this pull request?

It seems we are using `SQLUtils.getSQLDataType` to parse the type strings passed to `structField`. It looks like we can replace this with `CatalystSqlParser.parseDataType`.
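For reference, the new parser can be exercised directly on such strings. A minimal sketch (assuming only that Spark's `spark-catalyst` module is on the classpath):

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

// Parses a DDL-like type string into a Catalyst DataType.
CatalystSqlParser.parseDataType("struct<_1:string>")
// res0: org.apache.spark.sql.types.DataType = StructType(StructField(_1,StringType,true))
```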

Both accept similar DDL-like type definitions, as below:

```scala
scala> Seq(Tuple1(Tuple1("a"))).toDF.show()
```
```
+---+
| _1|
+---+
|[a]|
+---+
```

```scala
scala> Seq(Tuple1(Tuple1("a"))).toDF.select($"_1".cast("struct<_1:string>")).show()
```
```
+---+
| _1|
+---+
|[a]|
+---+
```

Such type strings look identical to the ones used on the R side, as below:

```R
> write.df(sql("SELECT named_struct('_1', 'a') as struct"), "/tmp/aa", "parquet")
> collect(read.df("/tmp/aa", "parquet", structType(structField("struct", "struct<_1:string>"))))
  struct
1      a
```

R's version is stricter, because we validate the type strings with regular expressions on the R side ahead of time.
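For example, the Catalyst parser tolerates whitespace that the R-side regular expressions reject; the `complexErrors` expectations in the test diff below pin down the R-side behaviour. A sketch:

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

// The Catalyst parser accepts the space after the comma ...
CatalystSqlParser.parseDataType("map<string, integer>")
// res0: org.apache.spark.sql.types.DataType = MapType(StringType,IntegerType,true)

// ... whereas the R-side check rejects "map<string, integer>" before the
// string ever reaches the JVM.
```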

The actual parsing logic differs in places, but since the strings are validated ahead of time on the R side, replacing it should not introduce any behaviour changes. To make sure of this, dedicated tests were added in SPARK-20105. (It looks like `structField` is the only place that calls this method.)
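One visible difference is the failure mode for strings that do reach the JVM: `getSQLDataType` threw an `IllegalArgumentException`, while the parser throws a `ParseException`, which the new branch in `captureJVMException` (see the `utils.R` diff below) rewrites into a `parse error - ...` message. A sketch of the JVM-side behaviour:

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

CatalystSqlParser.parseDataType("unknown")
// throws org.apache.spark.sql.catalyst.parser.ParseException with a message
// containing "DataType unknown is not supported."
```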

## How was this patch tested?

Existing tests at https://github.com/apache/spark/blob/master/R/pkg/inst/tests/testthat/test_sparkSQL.R#L143-L194 should cover this.

Author: hyukjinkwon <[email protected]>

Closes #17785 from HyukjinKwon/SPARK-20493.
HyukjinKwon authored and Felix Cheung committed Apr 29, 2017
1 parent ee694cd commit 70f1bcd
Showing 4 changed files with 24 additions and 46 deletions.
R/pkg/R/utils.R (8 additions, 0 deletions)

```diff
@@ -864,6 +864,14 @@ captureJVMException <- function(e, method) {
     # Extract the first message of JVM exception.
     first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
     stop(paste0(rmsg, "no such table - ", first), call. = FALSE)
+  } else if (any(grep("org.apache.spark.sql.catalyst.parser.ParseException: ", stacktrace))) {
+    msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.parser.ParseException: ",
+                    fixed = TRUE)[[1]]
+    # Extract "Error in ..." message.
+    rmsg <- msg[1]
+    # Extract the first message of JVM exception.
+    first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
+    stop(paste0(rmsg, "parse error - ", first), call. = FALSE)
   } else {
     stop(stacktrace, call. = FALSE)
   }
```
R/pkg/inst/tests/testthat/test_sparkSQL.R (11 additions, 2 deletions)

```diff
@@ -150,7 +150,12 @@ test_that("structField type strings", {
     binary = "BinaryType",
     boolean = "BooleanType",
     timestamp = "TimestampType",
-    date = "DateType")
+    date = "DateType",
+    tinyint = "ByteType",
+    smallint = "ShortType",
+    int = "IntegerType",
+    bigint = "LongType",
+    decimal = "DecimalType(10,0)")

   complexTypes <- list("map<string,integer>" = "MapType(StringType,IntegerType,true)",
                        "array<string>" = "ArrayType(StringType,true)",
@@ -174,7 +179,11 @@
     numeric = "numeric",
     character = "character",
     raw = "raw",
-    logical = "logical")
+    logical = "logical",
+    short = "short",
+    varchar = "varchar",
+    long = "long",
+    char = "char")

   complexErrors <- list("map<string, integer>" = " integer",
                         "array<String>" = "String",
```
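The new primitive expectations above line up with what the Catalyst parser produces; a quick sketch:

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

CatalystSqlParser.parseDataType("tinyint")  // ByteType
CatalystSqlParser.parseDataType("bigint")   // LongType
CatalystSqlParser.parseDataType("decimal")  // DecimalType(10,0), the default precision and scale
```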
R/pkg/inst/tests/testthat/test_utils.R (3 additions, 3 deletions)

```diff
@@ -167,13 +167,13 @@ test_that("convertToJSaveMode", {
 })

 test_that("captureJVMException", {
-  method <- "getSQLDataType"
+  method <- "createStructField"
   expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method,
-                                    "unknown"),
+                                    "col", "unknown", TRUE),
                error = function(e) {
                  captureJVMException(e, method)
                }),
-               "Error in getSQLDataType : illegal argument - Invalid type unknown")
+               "parse error - .*DataType unknown.*not supported.")
 })

 test_that("hashCode", {
```
sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala (2 additions, 41 deletions)

```diff
@@ -31,6 +31,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.execution.command.ShowTablesCommand
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.types._
@@ -92,48 +93,8 @@ private[sql] object SQLUtils extends Logging {
     def r: Regex = new Regex(sc.parts.mkString, sc.parts.tail.map(_ => "x"): _*)
   }

-  def getSQLDataType(dataType: String): DataType = {
-    dataType match {
-      case "byte" => org.apache.spark.sql.types.ByteType
-      case "integer" => org.apache.spark.sql.types.IntegerType
-      case "float" => org.apache.spark.sql.types.FloatType
-      case "double" => org.apache.spark.sql.types.DoubleType
-      case "numeric" => org.apache.spark.sql.types.DoubleType
-      case "character" => org.apache.spark.sql.types.StringType
-      case "string" => org.apache.spark.sql.types.StringType
-      case "binary" => org.apache.spark.sql.types.BinaryType
-      case "raw" => org.apache.spark.sql.types.BinaryType
-      case "logical" => org.apache.spark.sql.types.BooleanType
-      case "boolean" => org.apache.spark.sql.types.BooleanType
-      case "timestamp" => org.apache.spark.sql.types.TimestampType
-      case "date" => org.apache.spark.sql.types.DateType
-      case r"\Aarray<(.+)${elemType}>\Z" =>
-        org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType))
-      case r"\Amap<(.+)${keyType},(.+)${valueType}>\Z" =>
-        if (keyType != "string" && keyType != "character") {
-          throw new IllegalArgumentException("Key type of a map must be string or character")
-        }
-        org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType))
-      case r"\Astruct<(.+)${fieldsStr}>\Z" =>
-        if (fieldsStr(fieldsStr.length - 1) == ',') {
-          throw new IllegalArgumentException(s"Invalid type $dataType")
-        }
-        val fields = fieldsStr.split(",")
-        val structFields = fields.map { field =>
-          field match {
-            case r"\A(.+)${fieldName}:(.+)${fieldType}\Z" =>
-              createStructField(fieldName, fieldType, true)
-
-            case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
-          }
-        }
-        createStructType(structFields)
-      case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
-    }
-  }
-
   def createStructField(name: String, dataType: String, nullable: Boolean): StructField = {
-    val dtObj = getSQLDataType(dataType)
+    val dtObj = CatalystSqlParser.parseDataType(dataType)
     StructField(name, dtObj, nullable)
   }
```
