diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 8cd9e7bc60a03..6ce03a48e9538 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
+import java.io.IOException;
+import java.io.OutputStream;
+
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.util.ObjectPool;
 import org.apache.spark.unsafe.PlatformDependent;
@@ -371,6 +374,36 @@ public InternalRow copy() {
     }
   }
 
+  /**
+   * Write this UnsafeRow's underlying bytes to the given OutputStream.
+   *
+   * @param out the stream to write to.
+   * @param writeBuffer a byte array for buffering chunks of off-heap data while writing to the
+   *                    output stream. If this row is backed by an on-heap byte array, then this
+   *                    buffer will not be used and may be null.
+   */
+  public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException {
+    if (baseObject instanceof byte[]) {
+      int offsetInByteArray = (int) (baseOffset - PlatformDependent.BYTE_ARRAY_OFFSET);
+      out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes);
+    } else {
+      int dataRemaining = sizeInBytes;
+      long rowReadPosition = baseOffset;
+      while (dataRemaining > 0) {
+        int toTransfer = Math.min(writeBuffer.length, dataRemaining);
+        PlatformDependent.copyMemory(
+          baseObject,
+          rowReadPosition,
+          writeBuffer,
+          PlatformDependent.BYTE_ARRAY_OFFSET,
+          toTransfer);
+        out.write(writeBuffer, 0, toTransfer);
+        rowReadPosition += toTransfer;
+        dataRemaining -= toTransfer;
+      }
+    }
+  }
+
   @Override
   public boolean anyNull() {
     return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 19503ed00056c..318550e5ed899 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -49,8 +49,16 @@ private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with S
 
 private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {
+  /**
+   * Marks the end of a stream written with [[serializeStream()]].
+   */
   private[this] val EOF: Int = -1
 
+  /**
+   * Serializes a stream of UnsafeRows. Within the stream, each record consists of a record
+   * length (stored as a 4-byte integer, written high byte first), followed by the record's bytes.
+   * The end of the stream is denoted by a record with the special length `EOF` (-1).
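+   *
+   * For example, a 24-byte row is written as the four length bytes 0x00 0x00 0x00 0x18 followed
+   * by the row's 24 bytes, and the end-of-stream marker is the four bytes 0xFF 0xFF 0xFF 0xFF,
+   * since DataOutputStream.writeInt() emits the high byte first.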
+   */
   override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
     private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
     private[this] val dOut: DataOutputStream = new DataOutputStream(out)
@@ -59,32 +67,31 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
       val row = value.asInstanceOf[UnsafeRow]
       assert(row.getPool == null, "UnsafeRowSerializer does not support ObjectPool")
       dOut.writeInt(row.getSizeInBytes)
-      var dataRemaining: Int = row.getSizeInBytes
-      val baseObject = row.getBaseObject
-      var rowReadPosition: Long = row.getBaseOffset
-      while (dataRemaining > 0) {
-        val toTransfer: Int = Math.min(writeBuffer.length, dataRemaining)
-        PlatformDependent.copyMemory(
-          baseObject,
-          rowReadPosition,
-          writeBuffer,
-          PlatformDependent.BYTE_ARRAY_OFFSET,
-          toTransfer)
-        out.write(writeBuffer, 0, toTransfer)
-        rowReadPosition += toTransfer
-        dataRemaining -= toTransfer
-      }
+      row.writeToStream(out, writeBuffer)
       this
     }
+
     override def writeKey[T: ClassTag](key: T): SerializationStream = {
+      // The key is only needed on the map side when computing partition ids. It does not need to
+      // be shuffled.
       assert(key.isInstanceOf[Int])
       this
     }
-    override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream =
+
+    override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = {
+      // This method is never called by shuffle code.
       throw new UnsupportedOperationException
-    override def writeObject[T: ClassTag](t: T): SerializationStream =
+    }
+
+    override def writeObject[T: ClassTag](t: T): SerializationStream = {
+      // This method is never called by shuffle code.
       throw new UnsupportedOperationException
-    override def flush(): Unit = dOut.flush()
+    }
+
+    override def flush(): Unit = {
+      dOut.flush()
+    }
+
     override def close(): Unit = {
       writeBuffer = null
       dOut.writeInt(EOF)
@@ -95,6 +102,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
   override def deserializeStream(in: InputStream): DeserializationStream = {
     new DeserializationStream {
       private[this] val dIn: DataInputStream = new DataInputStream(in)
+      // 1024 is a default buffer size; this buffer will grow to accommodate larger rows
      private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
      private[this] var row: UnsafeRow = new UnsafeRow()
      private[this] var rowTuple: (Int, UnsafeRow) = (0, row)
@@ -126,14 +134,40 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
           }
         }
       }
-      override def asIterator: Iterator[Any] = throw new UnsupportedOperationException
-      override def readKey[T: ClassTag](): T = throw new UnsupportedOperationException
-      override def readValue[T: ClassTag](): T = throw new UnsupportedOperationException
-      override def readObject[T: ClassTag](): T = throw new UnsupportedOperationException
-      override def close(): Unit = dIn.close()
+
+      override def asIterator: Iterator[Any] = {
+        // This method is never called by shuffle code.
+        throw new UnsupportedOperationException
+      }
+
+      override def readKey[T: ClassTag](): T = {
+        // We skipped serialization of the key in writeKey(), so just return a dummy value since
+        // this is going to be discarded anyway.
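+        // (Under type erasure, null.asInstanceOf[T] is a no-op that returns null, so it
+        // satisfies any requested key type without allocating a dummy object.)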
+        null.asInstanceOf[T]
+      }
+
+      override def readValue[T: ClassTag](): T = {
+        val rowSize = dIn.readInt()
+        if (rowBuffer.length < rowSize) {
+          rowBuffer = new Array[Byte](rowSize)
+        }
+        ByteStreams.readFully(in, rowBuffer, 0, rowSize)
+        row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null)
+        row.asInstanceOf[T]
+      }
+
+      override def readObject[T: ClassTag](): T = {
+        // This method is never called by shuffle code.
+        throw new UnsupportedOperationException
+      }
+
+      override def close(): Unit = {
+        dIn.close()
+      }
     }
   }
 
+  // These methods are never called by shuffle code.
   override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException
   override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
     throw new UnsupportedOperationException
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
new file mode 100644
index 0000000000000..3854dc1b7a3d1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.ByteArrayOutputStream
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
+import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.unsafe.memory.MemoryAllocator
+import org.apache.spark.unsafe.types.UTF8String
+
+class UnsafeRowSuite extends SparkFunSuite {
+  test("writeToStream") {
+    val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123)
+    val arrayBackedUnsafeRow: UnsafeRow =
+      UnsafeProjection.create(Seq(StringType, StringType, IntegerType)).apply(row)
+    assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]])
+    val bytesFromArrayBackedRow: Array[Byte] = {
+      val baos = new ByteArrayOutputStream()
+      arrayBackedUnsafeRow.writeToStream(baos, null)
+      baos.toByteArray
+    }
+    val bytesFromOffheapRow: Array[Byte] = {
+      val offheapRowPage = MemoryAllocator.UNSAFE.allocate(arrayBackedUnsafeRow.getSizeInBytes)
+      try {
+        PlatformDependent.copyMemory(
+          arrayBackedUnsafeRow.getBaseObject,
+          arrayBackedUnsafeRow.getBaseOffset,
+          offheapRowPage.getBaseObject,
+          offheapRowPage.getBaseOffset,
+          arrayBackedUnsafeRow.getSizeInBytes
+        )
+        val offheapUnsafeRow: UnsafeRow = new UnsafeRow()
+        offheapUnsafeRow.pointTo(
+          offheapRowPage.getBaseObject,
+          offheapRowPage.getBaseOffset,
+          3, // num fields
+          arrayBackedUnsafeRow.getSizeInBytes,
+          null // object pool
+        )
+        assert(offheapUnsafeRow.getBaseObject === null)
+        val baos = new ByteArrayOutputStream()
+        val writeBuffer = new Array[Byte](1024)
+        offheapUnsafeRow.writeToStream(baos, writeBuffer)
+        baos.toByteArray
+      } finally {
+        MemoryAllocator.UNSAFE.free(offheapRowPage)
+      }
+    }
+
+    assert(bytesFromArrayBackedRow === bytesFromOffheapRow)
+  }
+}
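
Usage note (not part of the patch): a minimal sketch of how the new record format round-trips through the serializer. It assumes code living in the org.apache.spark.sql package, since UnsafeRowSerializer is private[sql], and a pool-free UnsafeRow like the one built in UnsafeRowSuite above; names such as `roundTripped` are illustrative only.

  import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
  import org.apache.spark.sql.execution.UnsafeRowSerializer
  import org.apache.spark.sql.types.{IntegerType, StringType}
  import org.apache.spark.unsafe.types.UTF8String

  // Build an UnsafeRow the same way UnsafeRowSuite does.
  val row: UnsafeRow = UnsafeProjection.create(Seq(StringType, StringType, IntegerType))
    .apply(InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123))

  val instance = new UnsafeRowSerializer(numFields = 3).newInstance()

  // Write side: each record is [4-byte length][row bytes]; close() appends the EOF (-1) marker.
  val baos = new ByteArrayOutputStream()
  val serStream = instance.serializeStream(baos)
  serStream.writeKey(0) // keys must be Ints but are never written to the stream
  serStream.writeValue(row)
  serStream.close()

  // Read side: readKey() returns a dummy; readValue() reads the length, then the row bytes.
  val deserStream = instance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
  deserStream.readKey[Int]()
  val roundTripped = deserStream.readValue[UnsafeRow]()
  assert(roundTripped.getSizeInBytes == row.getSizeInBytes)
  deserStream.close()

Note that the deserializer reuses a single UnsafeRow pointed at an internal buffer, so a caller that needs to retain a row past the next readValue() call must copy it first.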