[SPARK-9023] [SQL] Followup for #7456 (Efficiency improvements for UnsafeRows in Exchange)

This patch addresses code review feedback from #7456.

Author: Josh Rosen <[email protected]>

Closes #7551 from JoshRosen/unsafe-exchange-followup and squashes the following commits:

76dbdf8 [Josh Rosen] Add comments + more methods to UnsafeRowSerializer
3d7a1f2 [Josh Rosen] Add writeToStream() method to UnsafeRow
JoshRosen authored and rxin committed Jul 21, 2015
Parent: 67570be · Commit: 48f8fd4
Showing 3 changed files with 161 additions and 23 deletions.

UnsafeRow.java
@@ -17,6 +17,9 @@

 package org.apache.spark.sql.catalyst.expressions;

+import java.io.IOException;
+import java.io.OutputStream;
+
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.util.ObjectPool;
 import org.apache.spark.unsafe.PlatformDependent;
@@ -371,6 +374,36 @@ public InternalRow copy() {
     }
   }

+  /**
+   * Write this UnsafeRow's underlying bytes to the given OutputStream.
+   *
+   * @param out the stream to write to.
+   * @param writeBuffer a byte array for buffering chunks of off-heap data while writing to the
+   *                    output stream. If this row is backed by an on-heap byte array, then this
+   *                    buffer will not be used and may be null.
+   */
+  public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException {
+    if (baseObject instanceof byte[]) {
+      int offsetInByteArray = (int) (baseOffset - PlatformDependent.BYTE_ARRAY_OFFSET);
+      out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes);
+    } else {
+      int dataRemaining = sizeInBytes;
+      long rowReadPosition = baseOffset;
+      while (dataRemaining > 0) {
+        int toTransfer = Math.min(writeBuffer.length, dataRemaining);
+        PlatformDependent.copyMemory(
+          baseObject,
+          rowReadPosition,
+          writeBuffer,
+          PlatformDependent.BYTE_ARRAY_OFFSET,
+          toTransfer);
+        out.write(writeBuffer, 0, toTransfer);
+        rowReadPosition += toTransfer;
+        dataRemaining -= toTransfer;
+      }
+    }
+  }
+
   @Override
   public boolean anyNull() {
     return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
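
For context, a minimal usage sketch of the new method (not part of the commit). It mirrors the schema and values used in UnsafeRowSuite below; the object name WriteToStreamSketch is illustrative only. The null write buffer is allowed here because rows produced by UnsafeProjection are backed by an on-heap byte array, so the copy loop is skipped:

import java.io.ByteArrayOutputStream

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

object WriteToStreamSketch {
  def main(args: Array[String]): Unit = {
    // Build an array-backed UnsafeRow with the same schema as the test below.
    val row = InternalRow.apply(
      UTF8String.fromString("hello"), UTF8String.fromString("world"), 123)
    val unsafeRow =
      UnsafeProjection.create(Seq(StringType, StringType, IntegerType)).apply(row)

    // Array-backed rows may pass a null write buffer; off-heap rows must supply
    // a scratch buffer (e.g. new Array[Byte](4096)) for chunked copying.
    val out = new ByteArrayOutputStream()
    unsafeRow.writeToStream(out, null)
    assert(out.toByteArray.length == unsafeRow.getSizeInBytes)
  }
}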

UnsafeRowSerializer.scala
@@ -49,8 +49,16 @@ private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with Serializable {

 private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {

+  /**
+   * Marks the end of a stream written with [[serializeStream()]].
+   */
   private[this] val EOF: Int = -1

+  /**
+   * Serializes a stream of UnsafeRows. Within the stream, each record consists of a record
+   * length (stored as a 4-byte integer, written high byte first), followed by the record's bytes.
+   * The end of the stream is denoted by a record with the special length `EOF` (-1).
+   */
   override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
     private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
     private[this] val dOut: DataOutputStream = new DataOutputStream(out)
@@ -59,32 +67,31 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
       val row = value.asInstanceOf[UnsafeRow]
       assert(row.getPool == null, "UnsafeRowSerializer does not support ObjectPool")
       dOut.writeInt(row.getSizeInBytes)
-      var dataRemaining: Int = row.getSizeInBytes
-      val baseObject = row.getBaseObject
-      var rowReadPosition: Long = row.getBaseOffset
-      while (dataRemaining > 0) {
-        val toTransfer: Int = Math.min(writeBuffer.length, dataRemaining)
-        PlatformDependent.copyMemory(
-          baseObject,
-          rowReadPosition,
-          writeBuffer,
-          PlatformDependent.BYTE_ARRAY_OFFSET,
-          toTransfer)
-        out.write(writeBuffer, 0, toTransfer)
-        rowReadPosition += toTransfer
-        dataRemaining -= toTransfer
-      }
+      row.writeToStream(out, writeBuffer)
       this
     }
+
     override def writeKey[T: ClassTag](key: T): SerializationStream = {
+      // The key is only needed on the map side when computing partition ids. It does not need to
+      // be shuffled.
       assert(key.isInstanceOf[Int])
       this
     }
-    override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream =
+
+    override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = {
+      // This method is never called by shuffle code.
       throw new UnsupportedOperationException
-    override def writeObject[T: ClassTag](t: T): SerializationStream =
+    }
+
+    override def writeObject[T: ClassTag](t: T): SerializationStream = {
+      // This method is never called by shuffle code.
       throw new UnsupportedOperationException
-    override def flush(): Unit = dOut.flush()
+    }
+
+    override def flush(): Unit = {
+      dOut.flush()
+    }
+
     override def close(): Unit = {
       writeBuffer = null
       dOut.writeInt(EOF)
@@ -95,6 +102,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
   override def deserializeStream(in: InputStream): DeserializationStream = {
     new DeserializationStream {
       private[this] val dIn: DataInputStream = new DataInputStream(in)
+      // 1024 is a default buffer size; this buffer will grow to accommodate larger rows
       private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
       private[this] var row: UnsafeRow = new UnsafeRow()
       private[this] var rowTuple: (Int, UnsafeRow) = (0, row)
@@ -126,14 +134,40 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
           }
         }
       }
-      override def asIterator: Iterator[Any] = throw new UnsupportedOperationException
-      override def readKey[T: ClassTag](): T = throw new UnsupportedOperationException
-      override def readValue[T: ClassTag](): T = throw new UnsupportedOperationException
-      override def readObject[T: ClassTag](): T = throw new UnsupportedOperationException
-      override def close(): Unit = dIn.close()
+
+      override def asIterator: Iterator[Any] = {
+        // This method is never called by shuffle code.
+        throw new UnsupportedOperationException
+      }
+
+      override def readKey[T: ClassTag](): T = {
+        // We skipped serialization of the key in writeKey(), so just return a dummy value since
+        // this is going to be discarded anyways.
+        null.asInstanceOf[T]
+      }
+
+      override def readValue[T: ClassTag](): T = {
+        val rowSize = dIn.readInt()
+        if (rowBuffer.length < rowSize) {
+          rowBuffer = new Array[Byte](rowSize)
+        }
+        ByteStreams.readFully(in, rowBuffer, 0, rowSize)
+        row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null)
+        row.asInstanceOf[T]
+      }
+
+      override def readObject[T: ClassTag](): T = {
+        // This method is never called by shuffle code.
+        throw new UnsupportedOperationException
+      }
+
+      override def close(): Unit = {
+        dIn.close()
+      }
     }
   }
+
   // These methods are never called by shuffle code.
   override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException
   override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
     throw new UnsupportedOperationException
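
As a reading aid (not part of the commit), the record framing documented on serializeStream() above, a 4-byte length written high byte first for each record and a terminating length of -1, can be consumed like this. The object and method names here are illustrative only, and unlike the production deserializer this sketch allocates a fresh buffer per record instead of reusing one, trading efficiency for clarity:

import java.io.{DataInputStream, InputStream}

import com.google.common.io.ByteStreams

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.unsafe.PlatformDependent

object UnsafeRowStreamSketch {
  /** Reads length-prefixed UnsafeRow records until the EOF marker (-1). */
  def readRows(in: InputStream, numFields: Int): Iterator[UnsafeRow] = new Iterator[UnsafeRow] {
    private[this] val dIn = new DataInputStream(in)
    private[this] var nextSize = dIn.readInt()

    override def hasNext: Boolean = nextSize != -1

    override def next(): UnsafeRow = {
      // Copy the record's bytes into an on-heap buffer and point an UnsafeRow at it.
      val buffer = new Array[Byte](nextSize)
      ByteStreams.readFully(dIn, buffer, 0, nextSize)
      val row = new UnsafeRow()
      row.pointTo(buffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, nextSize, null)
      nextSize = dIn.readInt()
      row
    }
  }
}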

sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala (new file: 71 additions, 0 deletions)
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.ByteArrayOutputStream
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
+import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.unsafe.memory.MemoryAllocator
+import org.apache.spark.unsafe.types.UTF8String
+
+class UnsafeRowSuite extends SparkFunSuite {
+  test("writeToStream") {
+    val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123)
+    val arrayBackedUnsafeRow: UnsafeRow =
+      UnsafeProjection.create(Seq(StringType, StringType, IntegerType)).apply(row)
+    assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]])
+    val bytesFromArrayBackedRow: Array[Byte] = {
+      val baos = new ByteArrayOutputStream()
+      arrayBackedUnsafeRow.writeToStream(baos, null)
+      baos.toByteArray
+    }
+    val bytesFromOffheapRow: Array[Byte] = {
+      val offheapRowPage = MemoryAllocator.UNSAFE.allocate(arrayBackedUnsafeRow.getSizeInBytes)
+      try {
+        PlatformDependent.copyMemory(
+          arrayBackedUnsafeRow.getBaseObject,
+          arrayBackedUnsafeRow.getBaseOffset,
+          offheapRowPage.getBaseObject,
+          offheapRowPage.getBaseOffset,
+          arrayBackedUnsafeRow.getSizeInBytes
+        )
+        val offheapUnsafeRow: UnsafeRow = new UnsafeRow()
+        offheapUnsafeRow.pointTo(
+          offheapRowPage.getBaseObject,
+          offheapRowPage.getBaseOffset,
+          3, // num fields
+          arrayBackedUnsafeRow.getSizeInBytes,
+          null // object pool
+        )
+        assert(offheapUnsafeRow.getBaseObject === null)
+        val baos = new ByteArrayOutputStream()
+        val writeBuffer = new Array[Byte](1024)
+        offheapUnsafeRow.writeToStream(baos, writeBuffer)
+        baos.toByteArray
+      } finally {
+        MemoryAllocator.UNSAFE.free(offheapRowPage)
+      }
+    }
+
+    assert(bytesFromArrayBackedRow === bytesFromOffheapRow)
+  }
+}
