Fixing nested WriteSupport and adding tests
AndreSchumacher committed Jun 19, 2014
1 parent 1dc5ac9 commit e99cc51
Showing 4 changed files with 99 additions and 38 deletions.
@@ -141,60 +141,67 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
   }
 
   private[parquet] def writeValue(schema: DataType, value: Any): Unit = {
-    schema match {
-      case t @ ArrayType(_) => writeArray(t, value.asInstanceOf[Row])
-      case t @ MapType(_, _) => writeMap(t, value.asInstanceOf[Map[Any, Any]])
-      case t @ StructType(_) => writeStruct(t, value.asInstanceOf[Row])
-      case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value)
-    }
+    if (value != null && value != Nil) {
+      schema match {
+        case t @ ArrayType(_) => writeArray(t, value.asInstanceOf[Row])
+        case t @ MapType(_, _) => writeMap(t, value.asInstanceOf[Map[Any, Any]])
+        case t @ StructType(_) => writeStruct(t, value.asInstanceOf[Row])
+        case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value)
+      }
+    }
   }

   private[parquet] def writePrimitive(schema: PrimitiveType, value: Any): Unit = {
-    schema match {
-      case StringType => writer.addBinary(
-        Binary.fromByteArray(
-          value.asInstanceOf[String].getBytes("utf-8")
-        )
-      )
-      case IntegerType => writer.addInteger(value.asInstanceOf[Int])
-      case LongType => writer.addLong(value.asInstanceOf[Long])
-      case DoubleType => writer.addDouble(value.asInstanceOf[Double])
-      case FloatType => writer.addFloat(value.asInstanceOf[Float])
-      case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean])
-      case _ => sys.error(s"Do not know how to writer $schema to consumer")
-    }
+    if (value != null && value != Nil) {
+      schema match {
+        case StringType => writer.addBinary(
+          Binary.fromByteArray(
+            value.asInstanceOf[String].getBytes("utf-8")
+          )
+        )
+        case IntegerType => writer.addInteger(value.asInstanceOf[Int])
+        case LongType => writer.addLong(value.asInstanceOf[Long])
+        case DoubleType => writer.addDouble(value.asInstanceOf[Double])
+        case FloatType => writer.addFloat(value.asInstanceOf[Float])
+        case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean])
+        case _ => sys.error(s"Do not know how to writer $schema to consumer")
+      }
+    }
   }

   private[parquet] def writeStruct(schema: StructType, struct: Row): Unit = {
-    val fields = schema.fields.toArray
-    writer.startGroup()
-    var i = 0
-    while(i < fields.size) {
-      writer.startField(fields(i).name, i)
-      writeValue(fields(i).dataType, struct(i))
-      writer.endField(fields(i).name, i)
-      i = i + 1
+    if (struct != null && struct != Nil) {
+      val fields = schema.fields.toArray
+      writer.startGroup()
+      var i = 0
+      while(i < fields.size) {
+        if (struct(i) != null && struct(i) != Nil) {
+          writer.startField(fields(i).name, i)
+          writeValue(fields(i).dataType, struct(i))
+          writer.endField(fields(i).name, i)
+        }
+        i = i + 1
+      }
+      writer.endGroup()
     }
-    writer.endGroup()
   }

   private[parquet] def writeArray(schema: ArrayType, array: Row): Unit = {
     val elementType = schema.elementType
     writer.startGroup()
     if (array.size > 0) {
       writer.startField("values", 0)
       writer.startGroup()
       var i = 0
       while(i < array.size) {
         writeValue(elementType, array(i))
         i = i + 1
       }
       writer.endGroup()
       writer.endField("values", 0)
     }
     writer.endGroup()
   }
 
   // TODO: this does not allow null values! Should these be supported?
   private[parquet] def writeMap(schema: MapType, map: Map[_, _]): Unit = {
     writer.startGroup()
     if (map.size > 0) {
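
Side note on the hunk above: every write method now guards on `value != null && value != Nil` and skips the field entirely, so an optional field that is null is never started on the Parquet record consumer. A minimal standalone sketch of that control flow (the `Field` case class and the println calls are hypothetical stand-ins for Parquet's RecordConsumer API, not Spark code; only the null-skipping pattern mirrors the commit):

object NullSkippingSketch {
  case class Field(name: String, value: Any)

  // Mirrors writeStruct above: a null (or Nil) value means the field
  // is not started at all, rather than written as an explicit null.
  def writeRecord(fields: Seq[Field]): Unit = {
    var i = 0
    while (i < fields.size) {
      if (fields(i).value != null && fields(i).value != Nil) {
        println(s"startField(${fields(i).name}, $i)")
        println(s"  addValue(${fields(i).value})")
        println(s"endField(${fields(i).name}, $i)")
      }
      i = i + 1
    }
  }

  def main(args: Array[String]): Unit = {
    // Only "owner" produces start/add/end calls; "nickname" is skipped.
    writeRecord(Seq(Field("owner", "Julien Le Dem"), Field("nickname", null)))
  }
}
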
@@ -182,13 +182,13 @@ private[sql] object ParquetTestData
 |optional group data1 {
 |repeated group map {
 |required binary key;
-|optional int32 value;
+|required int32 value;
 |}
 |}
 |required group data2 {
 |repeated group map {
 |required binary key;
-|optional group value {
+|required group value {
 |required int64 payload1;
 |optional binary payload2;
 |}
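
As a cross-check of the tightened test schema (a sketch, assuming the pre-relocation `parquet.schema` package that Spark depended on at the time), the `data1` fragment above can be re-parsed to confirm that map values are now declared required, i.e. non-nullable:

import parquet.schema.MessageTypeParser
import parquet.schema.Type.Repetition

object SchemaRepetitionCheck {
  def main(args: Array[String]): Unit = {
    // Same data1 fragment as in the test schema above.
    val schema = MessageTypeParser.parseMessageType(
      """message m {
        |  optional group data1 {
        |    repeated group map {
        |      required binary key;
        |      required int32 value;
        |    }
        |  }
        |}""".stripMargin)
    val map = schema.getType("data1").asGroupType().getType("map").asGroupType()
    // The value field's repetition changed from OPTIONAL to REQUIRED.
    assert(map.getType("value").getRepetition == Repetition.REQUIRED)
    println(map)
  }
}
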
@@ -64,13 +64,14 @@ private[parquet] object ParquetTypesConverter {
  * <ul>
  * <li> Primitive types are converter to the corresponding primitive type.</li>
  * <li> Group types that have a single field that is itself a group, which has repetition
- * level `REPEATED` are treated as follows:<ul>
- * <li> If the nested group has name `values` and repetition level `REPEATED`, the
- * surrounding group is converted into an [[ArrayType]] with the
- * corresponding field type (primitive or complex) as element type.</li>
- * <li> If the nested group has name `map`, repetition level `REPEATED` and two fields
- * (named `key` and `value`), the surrounding group is converted into a [[MapType]]
- * with the corresponding key and value (value possibly complex) types.</li>
+ * level `REPEATED`, are treated as follows:<ul>
+ * <li> If the nested group has name `values`, the surrounding group is converted
+ * into an [[ArrayType]] with the corresponding field type (primitive or
+ * complex) as element type.</li>
+ * <li> If the nested group has name `map` and two fields (named `key` and `value`),
+ * the surrounding group is converted into a [[MapType]]
+ * with the corresponding key and value (value possibly complex) types.
+ * Note that we currently assume map values are not nullable.</li>
  * <li> Other group types are converted into a [[StructType]] with the corresponding
  * field types.</li></ul></li>
  * </ul>
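
A worked example of the rules documented above (a sketch: the Catalyst types are written out by hand from the rules, using the `org.apache.spark.sql.catalyst.types` names of this Spark version, not produced by calling the package-private converter):

import org.apache.spark.sql.catalyst.types._

object ConversionRuleSketch {
  def main(args: Array[String]): Unit = {
    // optional group ints { repeated int32 values; }
    //   -- single repeated field named "values" => array
    val arrayCase = ArrayType(IntegerType)

    // required group lookup {
    //   repeated group map { required binary key; required int64 value; }
    // }
    //   -- repeated group "map" with key/value fields => map
    //      (binary keys are read back as strings)
    val mapCase = MapType(StringType, LongType)

    // optional group point { required int32 x; required int32 y; }
    //   -- any other group => struct
    val structCase = StructType(Seq(
      StructField("x", IntegerType, nullable = false),
      StructField("y", IntegerType, nullable = false)))

    println(Seq(arrayCase, mapCase, structCase).mkString("\n"))
  }
}
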
@@ -121,15 +122,19 @@
       keyValueGroup.getFieldCount == 2,
       "Parquet Map type malformatted: nested group should have 2 (key, value) fields!")
     val keyType = toDataType(keyValueGroup.getFields.apply(0))
+    assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)
     val valueType = toDataType(keyValueGroup.getFields.apply(1))
+    assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED)
     new MapType(keyType, valueType)
   }
   case _ => {
     // Note: the order of these checks is important!
     if (correspondsToMap(groupType)) { // MapType
       val keyValueGroup = groupType.getFields.apply(0).asGroupType()
       val keyType = toDataType(keyValueGroup.getFields.apply(0))
+      assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)
       val valueType = toDataType(keyValueGroup.getFields.apply(1))
+      assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED)
       new MapType(keyType, valueType)
     } else if (correspondsToArray(groupType)) { // ArrayType
       val elementType = toDataType(groupType.getFields.apply(0))
@@ -240,13 +245,13 @@
         fromDataType(
           keyType,
           CatalystConverter.MAP_KEY_SCHEMA_NAME,
-          false,
+          nullable = false,
           inArray = false)
       val parquetValueType =
         fromDataType(
           valueType,
           CatalystConverter.MAP_VALUE_SCHEMA_NAME,
-          true,
+          nullable = false,
           inArray = false)
       ConversionPatterns.mapType(
         repetition,
@@ -553,6 +553,55 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
     assert(result2(0)(1) === "the answer")
   }
 
+  test("Writing out Addressbook and reading it back in") {
+    implicit def anyToRow(value: Any): Row = value.asInstanceOf[Row]
+    val tmpdir = Utils.createTempDir()
+    val result = TestSQLContext
+      .parquetFile(ParquetTestData.testNestedDir1.toString)
+      .toSchemaRDD
+    result.saveAsParquetFile(tmpdir.toString)
+    TestSQLContext
+      .parquetFile(tmpdir.toString)
+      .toSchemaRDD
+      .registerAsTable("tmpcopy")
+    val tmpdata = sql("SELECT owner, contacts[1].name FROM tmpcopy").collect()
+    assert(tmpdata.size === 2)
+    assert(tmpdata(0).size === 2)
+    assert(tmpdata(0)(0) === "Julien Le Dem")
+    assert(tmpdata(0)(1) === "Chris Aniszczyk")
+    assert(tmpdata(1)(0) === "A. Nonymous")
+    assert(tmpdata(1)(1) === null)
+    Utils.deleteRecursively(tmpdir)
+  }
+
+  test("Writing out Map and reading it back in") {
+    implicit def anyToMap(value: Any) = value.asInstanceOf[Map[String, Row]]
+    val data = TestSQLContext
+      .parquetFile(ParquetTestData.testNestedDir4.toString)
+      .toSchemaRDD
+    val tmpdir = Utils.createTempDir()
+    data.saveAsParquetFile(tmpdir.toString)
+    TestSQLContext
+      .parquetFile(tmpdir.toString)
+      .toSchemaRDD
+      .registerAsTable("tmpmapcopy")
+    val result1 = sql("SELECT data2 FROM tmpmapcopy").collect()
+    assert(result1.size === 1)
+    val entry1 = result1(0)(0).getOrElse("7", null)
+    assert(entry1 != null)
+    assert(entry1(0) === 42)
+    assert(entry1(1) === "the answer")
+    val entry2 = result1(0)(0).getOrElse("8", null)
+    assert(entry2 != null)
+    assert(entry2(0) === 49)
+    assert(entry2(1) === null)
+    val result2 = sql("SELECT data2[7].payload1, data2[7].payload2 FROM tmpmapcopy").collect()
+    assert(result2.size === 1)
+    assert(result2(0)(0) === 42.toLong)
+    assert(result2(0)(1) === "the answer")
+    Utils.deleteRecursively(tmpdir)
+  }
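
A note on the `implicit def anyToRow` / `anyToMap` definitions at the top of these tests: `Row#apply` returns `Any`, so nested fields would otherwise need an explicit `asInstanceOf` cast before they can be indexed again. A self-contained sketch of the same trick, with a toy Row class standing in for Catalyst's (not Spark code):

import scala.language.implicitConversions

object ImplicitCastSketch {
  // Toy stand-in: apply returns Any, like a Catalyst Row of this era.
  case class Row(values: Any*) {
    def apply(i: Int): Any = values(i)
  }

  // Without this, nested access reads result(0).asInstanceOf[Row](1).
  implicit def anyToRow(value: Any): Row = value.asInstanceOf[Row]

  def main(args: Array[String]): Unit = {
    val result = Row(Row(42, "the answer"))
    val nested: Row = result(0) // implicit cast from Any
    println(nested(1))          // prints "the answer"
  }
}
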

   /**
    * Creates an empty SchemaRDD backed by a ParquetRelation.
    *
