diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 4ac057d0f2d83..1c58fd96d750a 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -41,10 +41,7 @@ infer_type <- function(x) { if (type == "map") { stopifnot(length(x) > 0) key <- ls(x)[[1]] - list(type = "map", - keyType = "string", - valueType = infer_type(get(key, x)), - valueContainsNull = TRUE) + paste0("map") } else if (type == "array") { stopifnot(length(x) > 0) names <- names(x) diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index d1858ec227b56..ce88d0b071b72 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -50,6 +50,7 @@ readTypedObject <- function(con, type) { "t" = readTime(con), "a" = readArray(con), "l" = readList(con), + "e" = readEnv(con), "n" = NULL, "j" = getJobj(readString(con)), stop(paste("Unsupported type for deserialization", type))) @@ -121,6 +122,19 @@ readList <- function(con) { } } +readEnv <- function(con) { + env <- new.env() + len <- readInt(con) + if (len > 0) { + for (i in 1:len) { + key <- readString(con) + value <- readObject(con) + env[[key]] <- value + } + } + env +} + readRaw <- function(con) { dataLen <- readInt(con) readBin(con, raw(), as.integer(dataLen), endian = "big") diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 62d4f73878d29..8df1563f8ebc0 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -131,13 +131,33 @@ checkType <- function(type) { if (type %in% primtiveTypes) { return() } else { - m <- regexec("^array<(.*)>$", type) - matchedStrings <- regmatches(type, m) - if (length(matchedStrings[[1]]) >= 2) { - elemType <- matchedStrings[[1]][2] - checkType(elemType) - return() - } + # Check complex types + firstChar <- substr(type, 1, 1) + switch (firstChar, + a = { + # Array type + m <- regexec("^array<(.*)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + elemType <- matchedStrings[[1]][2] + checkType(elemType) + return() + } + }, + m = { + # Map type + m <- regexec("^map<(.*),(.*)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 3) { + keyType <- matchedStrings[[1]][2] + if (keyType != "string" && keyType != "character") { + stop("Key type in a map must be string or character") + } + valueType <- matchedStrings[[1]][3] + checkType(valueType) + return() + } + }) } stop(paste("Unsupported type for Dataframe:", type)) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 98d4402d368e1..e159a69584274 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -57,7 +57,7 @@ mockLinesComplexType <- complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLinesComplexType, complexTypeJsonPath) -test_that("infer types", { +test_that("infer types and check types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") expect_equal(infer_type("abc"), "string") @@ -72,9 +72,9 @@ test_that("infer types", { checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() assign("a", 1L, envir = e) - expect_equal(infer_type(e), - list(type = "map", keyType = "string", valueType = "integer", - valueContainsNull = TRUE)) + expect_equal(infer_type(e), "map") + + expect_error(checkType("map"), "Key type in a map must be string or character") }) test_that("structType and structField", { @@ -242,7 +242,7 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) -test_that("create DataFrame with nested array and struct", { +test_that("create DataFrame with nested array and map", { # e <- new.env() # assign("n", 3L, envir = e) # l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) @@ -253,21 +253,35 @@ test_that("create DataFrame with nested array and struct", { # ldf <- collect(df) # expect_equal(ldf[1,], l[[1]]) + # ArrayType and MapType + e <- new.env() + assign("n", 3L, envir = e) - # ArrayType only for now - l <- list(as.list(1:10), list("a", "b")) - df <- createDataFrame(sqlContext, list(l), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"))) + l <- list(as.list(1:10), list("a", "b"), e) + df <- createDataFrame(sqlContext, list(l), c("a", "b", "c")) + expect_equal(dtypes(df), list(c("a", "array"), + c("b", "array"), + c("c", "map"))) expect_equal(count(df), 1) ldf <- collect(df) - expect_equal(names(ldf), c("a", "b")) + expect_equal(names(ldf), c("a", "b", "c")) expect_equal(ldf[1, 1][[1]], l[[1]]) expect_equal(ldf[1, 2][[1]], l[[2]]) + e <- ldf$c[[1]] + expect_equal(class(e), "environment") + expect_equal(ls(e), "n") + expect_equal(e$n, 3L) }) +# For test map type in DataFrame +mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", + "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", + "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") +mapTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesMapType, mapTypeJsonPath) + test_that("Collect DataFrame with complex types", { - # only ArrayType now - # TODO: tests for StructType and MapType after they are supported + # ArrayType df <- jsonFile(sqlContext, complexTypeJsonPath) ldf <- collect(df) @@ -277,6 +291,24 @@ test_that("Collect DataFrame with complex types", { expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) + + # MapType + schema <- structType(structField("name", "string"), + structField("info", "map")) + df <- read.df(sqlContext, mapTypeJsonPath, "json", schema) + expect_equal(dtypes(df), list(c("name", "string"), + c("info", "map"))) + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 2) + expect_equal(names(ldf), c("name", "info")) + expect_equal(ldf$name, c("Bob", "Alice", "David")) + bob <- ldf$info[[1]] + expect_equal(class(bob), "environment") + expect_equal(bob$age, 16) + expect_equal(bob$height, 176.5) + + # TODO: tests for StructType after it is supported }) test_that("jsonFile() on a local file returns a DataFrame", { diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 3c92bb7a1c73c..0c78613e406e1 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -209,11 +209,23 @@ private[spark] object SerDe { case "array" => dos.writeByte('a') // Array of objects case "list" => dos.writeByte('l') + case "map" => dos.writeByte('e') case "jobj" => dos.writeByte('j') case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") } } + private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = { + if (key == null) { + throw new IllegalArgumentException("Key in map can't be null.") + } else if (!key.isInstanceOf[String]) { + throw new IllegalArgumentException(s"Invalid map key type: ${key.getClass.getName}") + } + + writeString(dos, key.asInstanceOf[String]) + writeObject(dos, value) + } + def writeObject(dos: DataOutputStream, obj: Object): Unit = { if (obj == null) { writeType(dos, "void") @@ -306,6 +318,25 @@ private[spark] object SerDe { writeInt(dos, v.length) v.foreach(elem => writeObject(dos, elem)) + // Handle map + case v: java.util.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + val iter = v.entrySet.iterator + while(iter.hasNext) { + val entry = iter.next + val key = entry.getKey + val value = entry.getValue + + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + } + case v: scala.collection.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + v.foreach { case (key, value) => + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + } + case _ => writeType(dos, "jobj") writeJObj(dos, value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index d4b834adb6e39..f45d119c8cfdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -64,6 +64,12 @@ private[r] object SQLUtils { case r"\Aarray<(.*)${elemType}>\Z" => { org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType)) } + case r"\Amap<(.*)${keyType},(.*)${valueType}>\Z" => { + if (keyType != "string" && keyType != "character") { + throw new IllegalArgumentException("Key type of a map must be string or character") + } + org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType)) + } case _ => throw new IllegalArgumentException(s"Invaid type $dataType") } }