diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 8aeb3e9c4aad9..318550e5ed899 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -49,8 +49,16 @@ private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with S private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance { + /** + * Marks the end of a stream written with [[serializeStream()]]. + */ private[this] val EOF: Int = -1 + /** + * Serializes a stream of UnsafeRows. Within the stream, each record consists of a record + * length (stored as a 4-byte integer, written high byte first), followed by the record's bytes. + * The end of the stream is denoted by a record with the special length `EOF` (-1). + */ override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096) private[this] val dOut: DataOutputStream = new DataOutputStream(out) @@ -62,15 +70,28 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst row.writeToStream(out, writeBuffer) this } + override def writeKey[T: ClassTag](key: T): SerializationStream = { + // The key is only needed on the map side when computing partition ids. It does not need to + // be shuffled. assert(key.isInstanceOf[Int]) this } - override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = + + override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = { + // This method is never called by shuffle code. 
throw new UnsupportedOperationException - override def writeObject[T: ClassTag](t: T): SerializationStream = + } + + override def writeObject[T: ClassTag](t: T): SerializationStream = { + // This method is never called by shuffle code. throw new UnsupportedOperationException - override def flush(): Unit = dOut.flush() + } + + override def flush(): Unit = { + dOut.flush() + } + override def close(): Unit = { writeBuffer = null dOut.writeInt(EOF) @@ -81,6 +102,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def deserializeStream(in: InputStream): DeserializationStream = { new DeserializationStream { private[this] val dIn: DataInputStream = new DataInputStream(in) + // 1024 is a default buffer size; this buffer will grow to accommodate larger rows private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) private[this] var row: UnsafeRow = new UnsafeRow() private[this] var rowTuple: (Int, UnsafeRow) = (0, row) @@ -112,14 +134,40 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst } } } - override def asIterator: Iterator[Any] = throw new UnsupportedOperationException - override def readKey[T: ClassTag](): T = throw new UnsupportedOperationException - override def readValue[T: ClassTag](): T = throw new UnsupportedOperationException - override def readObject[T: ClassTag](): T = throw new UnsupportedOperationException - override def close(): Unit = dIn.close() + + override def asIterator: Iterator[Any] = { + // This method is never called by shuffle code. + throw new UnsupportedOperationException + } + + override def readKey[T: ClassTag](): T = { + // We skipped serialization of the key in writeKey(), so just return a dummy value since + // this is going to be discarded anyway. 
+ null.asInstanceOf[T] + } + + override def readValue[T: ClassTag](): T = { + val rowSize = dIn.readInt() + if (rowBuffer.length < rowSize) { + rowBuffer = new Array[Byte](rowSize) + } + ByteStreams.readFully(in, rowBuffer, 0, rowSize) + row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null) + row.asInstanceOf[T] + } + + override def readObject[T: ClassTag](): T = { + // This method is never called by shuffle code. + throw new UnsupportedOperationException + } + + override def close(): Unit = { + dIn.close() + } } } + // These methods are never called by shuffle code. override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException override def deserialize[T: ClassTag](bytes: ByteBuffer): T = throw new UnsupportedOperationException