Commit 1dc5ac9

First version of WriteSupport for nested types
AndreSchumacher committed Jun 19, 2014
1 parent d1911dc commit 1dc5ac9
Showing 3 changed files with 135 additions and 28 deletions.
@@ -166,9 +166,15 @@ case class InsertIntoParquetTable(

val job = new Job(sc.hadoopConfiguration)

ParquetOutputFormat.setWriteSupportClass(
job,
classOf[org.apache.spark.sql.parquet.RowWriteSupport])
val writeSupport =
if (child.output.map(_.dataType).forall(_.isPrimitive())) {
logger.info("Initializing MutableRowWriteSupport")
classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport]
} else {
classOf[org.apache.spark.sql.parquet.RowWriteSupport]
}

ParquetOutputFormat.setWriteSupportClass(job, writeSupport)

// TODO: move this into a function in an object
val conf = ContextUtil.getConfiguration(job)
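
The write support is now chosen per query: a flat (all-primitive) output schema uses the new MutableRowWriteSupport fast path, while any nested type falls back to the generic RowWriteSupport. A minimal sketch of that predicate, assuming Catalyst's DataType.isPrimitive() as used above (the example type lists are hypothetical):

import org.apache.spark.sql.catalyst.types._

def usesMutableWriteSupport(types: Seq[DataType]): Boolean =
  types.forall(_.isPrimitive())

usesMutableWriteSupport(Seq(IntegerType, StringType))             // true  -> MutableRowWriteSupport
usesMutableWriteSupport(Seq(IntegerType, ArrayType(StringType)))  // false -> RowWriteSupport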
@@ -28,6 +28,10 @@ import parquet.schema.{MessageType, MessageTypeParser}

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, Row}
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.types.ArrayType
import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.catalyst.types.MapType

/**
* A `parquet.io.api.RecordMaterializer` for Rows.
@@ -97,9 +101,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
MessageTypeParser.parseMessageType(configuration.get(RowWriteSupport.PARQUET_ROW_SCHEMA))
}

private var schema: MessageType = null
private var writer: RecordConsumer = null
private var attributes: Seq[Attribute] = null
private[parquet] var schema: MessageType = null
private[parquet] var writer: RecordConsumer = null
private[parquet] var attributes: Seq[Attribute] = null

override def init(configuration: Configuration): WriteSupport.WriteContext = {
schema = if (schema == null) getSchema(configuration) else schema
@@ -116,7 +120,6 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
log.debug(s"preparing for write with schema $schema")
}

// TODO: add groups (nested fields)
override def write(record: Row): Unit = {
if (attributes.size > record.size) {
throw new IndexOutOfBoundsException(
@@ -129,13 +132,131 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
// null values indicate optional fields, but we do not currently check for them
if (record(index) != null && record(index) != Nil) {
writer.startField(attributes(index).name, index)
ParquetTypesConverter.consumeType(writer, attributes(index).dataType, record, index)
writeValue(attributes(index).dataType, record(index))
writer.endField(attributes(index).name, index)
}
index = index + 1
}
writer.endMessage()
}

private[parquet] def writeValue(schema: DataType, value: Any): Unit = {
schema match {
case t @ ArrayType(_) => writeArray(t, value.asInstanceOf[Row])
case t @ MapType(_, _) => writeMap(t, value.asInstanceOf[Map[Any, Any]])
case t @ StructType(_) => writeStruct(t, value.asInstanceOf[Row])
case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value)
}
}
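// Illustration for writeValue above (not part of the original change): the kinds of
// (schema, value) pairs it dispatches on. Nested struct and array values are assumed to
// arrive as Row instances, and map values as Scala Maps, e.g.:
//   writeValue(StructType(Seq(StructField("a", IntegerType, false))), row)
//   writeValue(ArrayType(IntegerType), rowOfElements)
//   writeValue(MapType(StringType, IntegerType), Map("a" -> 1))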

private[parquet] def writePrimitive(schema: PrimitiveType, value: Any): Unit = {
schema match {
case StringType => writer.addBinary(
Binary.fromByteArray(
value.asInstanceOf[String].getBytes("utf-8")
)
)
case IntegerType => writer.addInteger(value.asInstanceOf[Int])
case LongType => writer.addLong(value.asInstanceOf[Long])
case DoubleType => writer.addDouble(value.asInstanceOf[Double])
case FloatType => writer.addFloat(value.asInstanceOf[Float])
case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean])
case _ => sys.error(s"Do not know how to writer $schema to consumer")
}
}

private[parquet] def writeStruct(schema: StructType, struct: Row): Unit = {
val fields = schema.fields.toArray
writer.startGroup()
var i = 0
while (i < fields.size) {
writer.startField(fields(i).name, i)
writeValue(fields(i).dataType, struct(i))
writer.endField(fields(i).name, i)
i = i + 1
}
writer.endGroup()
}

private[parquet] def writeArray(schema: ArrayType, array: Row): Unit = {
val elementType = schema.elementType
writer.startGroup()
if (array.size > 0) {
writer.startField("values", 0)
writer.startGroup()
var i = 0
while (i < array.size) {
writeValue(elementType, array(i))
i = i + 1
}
writer.endGroup()
writer.endField("values", 0)
}
writer.endGroup()
}
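// Note on writeArray above: an array is written as a group wrapping a single "values" field;
// the element values are emitted back-to-back inside one nested group rather than as
// individually repeated fields.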

private[parquet] def writeMap(schema: MapType, map: Map[_, _]): Unit = {
writer.startGroup()
if (map.size > 0) {
writer.startField("map", 0)
writer.startGroup()
writer.startField("key", 0)
for (key <- map.keys) {
writeValue(schema.keyType, key)
}
writer.endField("key", 0)
writer.startField("value", 1)
for (value <- map.values) {
writeValue(schema.valueType, value)
}
writer.endField("value", 1)
writer.endGroup()
writer.endField("map", 0)
}
writer.endGroup()
}
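// Note on writeMap above: entries are not written interleaved; all keys are emitted under the
// "key" field and then all values under the "value" field of a single nested group inside the
// outer "map" field.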
}

// Optimized for non-nested rows
private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
override def write(record: Row): Unit = {
if (attributes.size > record.size) {
throw new IndexOutOfBoundsException(
s"Trying to write more fields than contained in row (${attributes.size}>${record.size})")
}

var index = 0
writer.startMessage()
while (index < attributes.size) {
// null values indicate optional fields, but we do not currently check for them
if (record(index) != null && record(index) != Nil) {
writer.startField(attributes(index).name, index)
consumeType(attributes(index).dataType, record, index)
writer.endField(attributes(index).name, index)
}
index = index + 1
}
writer.endMessage()
}

private def consumeType(
ctype: DataType,
record: Row,
index: Int): Unit = {
ctype match {
case StringType => writer.addBinary(
Binary.fromByteArray(
record(index).asInstanceOf[String].getBytes("utf-8")
)
)
case IntegerType => writer.addInteger(record.getInt(index))
case LongType => writer.addLong(record.getLong(index))
case DoubleType => writer.addDouble(record.getDouble(index))
case FloatType => writer.addFloat(record.getFloat(index))
case BooleanType => writer.addBoolean(record.getBoolean(index))
case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
}
}
}
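
// Note on the fast path above (illustrative, not part of the original change): the typed Row
// getters return primitives directly, while the generic apply used in RowWriteSupport goes
// through Any and boxes the value first, e.g.
//   record(index).asInstanceOf[Int]   // generic access, value is boxed
//   record.getInt(index)              // typed access, no boxing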

private[parquet] object RowWriteSupport {
@@ -259,26 +259,6 @@ private[parquet] object ParquetTypesConverter {
}
}

def consumeType(
consumer: RecordConsumer,
ctype: DataType,
record: Row,
index: Int): Unit = {
ctype match {
case StringType => consumer.addBinary(
Binary.fromByteArray(
record(index).asInstanceOf[String].getBytes("utf-8")
)
)
case IntegerType => consumer.addInteger(record.getInt(index))
case LongType => consumer.addLong(record.getLong(index))
case DoubleType => consumer.addDouble(record.getDouble(index))
case FloatType => consumer.addFloat(record.getFloat(index))
case BooleanType => consumer.addBoolean(record.getBoolean(index))
case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
}
}

def getSchema(schemaString: String) : MessageType =
MessageTypeParser.parseMessageType(schemaString)
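
// Example (hypothetical schema string):
//   getSchema("message spark_schema { required int32 id; required binary name; }")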

