From a25c1ab8f04a4e19d82ff4c18a0b1689d8b3ddac Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Thu, 21 May 2015 09:58:47 -0700
Subject: [PATCH] [SPARK-7565] [SQL] fix MapType in JsonRDD

The keys of a Map parsed by JsonRDD should be converted into UTF8String,
and so should the contents of failed (corrupt) records. Thanks to yhuai
and viirya.

Closes #6084

Author: Davies Liu

Closes #6299 from davies/string_in_json and squashes the following commits:

0dbf559 [Davies Liu] improve test, fix corrupt record
6836a80 [Davies Liu] move unit tests into Scala
b97af11 [Davies Liu] fix MapType in JsonRDD
---
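Editor's note: the key-conversion pattern at the heart of this patch, sketched
as a standalone snippet. UTF8String below is a hypothetical stand-in for
org.apache.spark.sql.types.UTF8String (its Spark 1.4-era location), written
out so the snippet compiles on its own; convertKeys mirrors the new MapType
branch of enforceCorrectType. The trailing .map(identity) in the real code
appears to be a holdover from the mapValues form, where it forced the lazy,
non-serializable view returned by mapValues into a strict Map.

    // Stand-in for org.apache.spark.sql.types.UTF8String: a thin wrapper
    // holding string data as UTF-8 encoded bytes, which is how Spark SQL's
    // internal rows represent StringType columns.
    case class UTF8String(bytes: Seq[Byte]) {
      override def toString: String = new String(bytes.toArray, "UTF-8")
    }

    object UTF8String {
      def apply(s: String): UTF8String = UTF8String(s.getBytes("UTF-8").toSeq)
    }

    object MapKeyConversion {
      // Convert every key of a parsed JSON object to UTF8String, leaving
      // value coercion (enforceCorrectType in the real code) elided.
      def convertKeys(parsed: Map[String, Any]): Map[UTF8String, Any] =
        parsed.map { case (k, v) => (UTF8String(k), v) }

      def main(args: Array[String]): Unit = {
        // Prints: Map(a -> 1, b -> 2) -- the keys are now UTF8String wrappers.
        println(convertKeys(Map("a" -> 1, "b" -> 2)))
      }
    }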
 .../apache/spark/sql/json/JacksonParser.scala |  8 +++---
 .../org/apache/spark/sql/json/JsonRDD.scala   | 16 +++++++----
 .../org/apache/spark/sql/json/JsonSuite.scala | 28 ++++++++++++++++++-
 3 files changed, 41 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
index 81611513582a8..0e223758051a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -150,10 +150,10 @@ private[sql] object JacksonParser {
   private def convertMap(
       factory: JsonFactory,
       parser: JsonParser,
-      valueType: DataType): Map[String, Any] = {
-    val builder = Map.newBuilder[String, Any]
+      valueType: DataType): Map[UTF8String, Any] = {
+    val builder = Map.newBuilder[UTF8String, Any]
     while (nextUntil(parser, JsonToken.END_OBJECT)) {
-      builder += parser.getCurrentName -> convertField(factory, parser, valueType)
+      builder += UTF8String(parser.getCurrentName) -> convertField(factory, parser, valueType)
     }
 
     builder.result()
@@ -181,7 +181,7 @@ private[sql] object JacksonParser {
       val row = new GenericMutableRow(schema.length)
       for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) {
         require(schema(corruptIndex).dataType == StringType)
-        row.update(corruptIndex, record)
+        row.update(corruptIndex, UTF8String(record))
       }
 
       Seq(row)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 4c32710a17bc7..037a6d60a2ed6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -20,18 +20,18 @@ package org.apache.spark.sql.json
 import java.sql.Timestamp
 
 import scala.collection.Map
-import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
+import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
 
-import com.fasterxml.jackson.core.{JsonGenerator, JsonProcessingException}
+import com.fasterxml.jackson.core.JsonProcessingException
 import com.fasterxml.jackson.databind.ObjectMapper
 
+import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.util.DateUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.Logging
 
 private[sql] object JsonRDD extends Logging {
 
@@ -318,7 +318,8 @@ private[sql] object JsonRDD extends Logging {
 
           parsed
         } catch {
-          case e: JsonProcessingException => Map(columnNameOfCorruptRecords -> record) :: Nil
+          case e: JsonProcessingException =>
+            Map(columnNameOfCorruptRecords -> UTF8String(record)) :: Nil
         }
       }
     })
@@ -422,7 +423,10 @@ private[sql] object JsonRDD extends Logging {
         value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
       case MapType(StringType, valueType, _) =>
         val map = value.asInstanceOf[Map[String, Any]]
-        map.mapValues(enforceCorrectType(_, valueType)).map(identity)
+        map.map {
+          case (k, v) =>
+            (UTF8String(k), enforceCorrectType(v, valueType))
+        }.map(identity)
       case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct)
       case DateType => toDate(value)
       case TimestampType => toTimestamp(value)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 6f747e5846f74..7e6eeba17752a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -25,7 +25,6 @@ import org.scalactic.Tolerance._
 
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.util.DateUtils
-import org.apache.spark.sql.functions._
 import org.apache.spark.sql.json.InferSchema.compatibleType
 import org.apache.spark.sql.sources.LogicalRelation
 import org.apache.spark.sql.test.TestSQLContext
@@ -1074,4 +1073,31 @@ class JsonSuite extends QueryTest {
     assert(StructType(Seq()) === emptySchema)
   }
 
+  test("SPARK-7565 MapType in JsonRDD") {
+    val useStreaming = getConf(SQLConf.USE_JACKSON_STREAMING_API, "true")
+    val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord
+    TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
+
+    val schemaWithSimpleMap = StructType(
+      StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
+    try {
+      for (useStreaming <- List("true", "false")) {
+        setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
+        val temp = Utils.createTempDir().getPath
+
+        val df = read.schema(schemaWithSimpleMap).json(mapType1)
+        df.write.mode("overwrite").parquet(temp)
+        // order of MapType is not defined
+        assert(read.parquet(temp).count() == 5)
+
+        val df2 = read.json(corruptRecords)
+        df2.write.mode("overwrite").parquet(temp)
+        checkAnswer(read.parquet(temp), df2.collect())
+      }
+    } finally {
+      setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
+      setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
+    }
+  }
+
 }
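
Editor's note: a hedged, end-to-end sketch of the round trip the new test
exercises, written against the Spark 1.4-era API (SparkContext/SQLContext;
the object name and the /tmp output path are placeholders, and the sample
JSON is invented for illustration). Before this patch, maps parsed by
JsonRDD kept java.lang.String keys while the row's schema declared
StringType columns backed by UTF8String, so writing such a DataFrame out
to Parquet could fail.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.types._

    object MapTypeRoundTrip {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("SPARK-7565-repro"))
        val sqlContext = new SQLContext(sc)

        // Same shape as schemaWithSimpleMap in the new test: one non-nullable
        // map column with string keys and nullable integer values.
        val schema = StructType(
          StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
        val json = sc.parallelize(
          """{"map": {"a": 1}}""" :: """{"map": {"b": 2, "c": 3}}""" :: Nil)

        // Read with the explicit schema, then round-trip through Parquet;
        // this is the path that required UTF8String map keys.
        val df = sqlContext.read.schema(schema).json(json)
        df.write.mode("overwrite").parquet("/tmp/spark-7565-repro")
        println(sqlContext.read.parquet("/tmp/spark-7565-repro").count()) // 2

        sc.stop()
      }
    }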