From 2d6a5fb19119402254f277bdfaba373d32444612 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 19 Aug 2014 00:43:00 -0700
Subject: [PATCH] Use putBytes/getRemoteBytes throughout.

---
 .../spark/broadcast/TorrentBroadcast.scala    | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 19b1b1ca2d5ca..0ceeb1a524905 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -27,6 +27,7 @@ import scala.util.Random
 import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
+import org.apache.spark.util.ByteBufferInputStream
 
 /**
  * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
@@ -76,8 +77,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](
   private def writeBlocks(): Int = {
     val blocks = TorrentBroadcast.blockifyObject(_value)
     blocks.zipWithIndex.foreach { case (block, i) =>
-      // TODO: Use putBytes directly.
-      SparkEnv.get.blockManager.putSingle(
+      SparkEnv.get.blockManager.putBytes(
         BroadcastBlockId(id, "piece" + i),
         block,
         StorageLevel.MEMORY_AND_DISK_SER,
@@ -87,21 +87,21 @@ private[spark] class TorrentBroadcast[T: ClassTag](
   }
 
   /** Fetch torrent blocks from the driver and/or other executors. */
-  private def readBlocks(): Array[Array[Byte]] = {
+  private def readBlocks(): Array[ByteBuffer] = {
     // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
     // to the driver, so other executors can pull these chunks from this executor as well.
     var numBlocksAvailable = 0
-    val blocks = new Array[Array[Byte]](numBlocks)
+    val blocks = new Array[ByteBuffer](numBlocks)
 
     for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
       val pieceId = BroadcastBlockId(id, "piece" + pid)
-      SparkEnv.get.blockManager.getSingle(pieceId) match {
+      SparkEnv.get.blockManager.getRemoteBytes(pieceId) match {
         case Some(x) =>
-          blocks(pid) = x.asInstanceOf[Array[Byte]]
+          blocks(pid) = x.asInstanceOf[ByteBuffer]
           numBlocksAvailable += 1
           SparkEnv.get.blockManager.putBytes(
             pieceId,
-            ByteBuffer.wrap(blocks(pid)),
+            blocks(pid),
             StorageLevel.MEMORY_AND_DISK_SER,
             tellMaster = true)
 
@@ -182,7 +182,7 @@ private object TorrentBroadcast extends Logging {
     initialized = false
   }
 
-  def blockifyObject[T: ClassTag](obj: T): Array[Array[Byte]] = {
+  def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = {
     // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks
     // so we don't need to do the extra memory copy.
     val bos = new ByteArrayOutputStream()
@@ -193,7 +193,7 @@ private object TorrentBroadcast extends Logging {
     val byteArray = bos.toByteArray
     val bais = new ByteArrayInputStream(byteArray)
     val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt
-    val blocks = new Array[Array[Byte]](numBlocks)
+    val blocks = new Array[ByteBuffer](numBlocks)
 
     var blockId = 0
     for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
@@ -201,16 +201,16 @@ private object TorrentBroadcast extends Logging {
       val tempByteArray = new Array[Byte](thisBlockSize)
       bais.read(tempByteArray, 0, thisBlockSize)
-      blocks(blockId) = tempByteArray
+      blocks(blockId) = ByteBuffer.wrap(tempByteArray)
       blockId += 1
     }
     bais.close()
 
     blocks
   }
 
-  def unBlockifyObject[T: ClassTag](blocks: Array[Array[Byte]]): T = {
+  def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = {
     val is = new SequenceInputStream(
-      asJavaEnumeration(blocks.iterator.map(block => new ByteArrayInputStream(block))))
+      asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block))))
     val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is
     val ser = SparkEnv.get.serializer.newInstance()
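
For reference, below is a minimal, self-contained sketch of the chunk-and-reassemble
round trip this patch moves TorrentBroadcast to. It is plain Scala, runnable without
Spark: SimpleByteBufferInputStream, BlockifyDemo, blockify, and unblockify are all
invented here for illustration, with SimpleByteBufferInputStream playing the role of
org.apache.spark.util.ByteBufferInputStream.

// Illustrative sketch only -- not part of the patch and not Spark code.
import java.io.{InputStream, SequenceInputStream}
import java.nio.ByteBuffer
import scala.collection.JavaConverters._

// Adapts a ByteBuffer to the InputStream interface, the role
// ByteBufferInputStream plays in unBlockifyObject after this patch.
class SimpleByteBufferInputStream(buf: ByteBuffer) extends InputStream {
  override def read(): Int =
    if (buf.hasRemaining) buf.get() & 0xFF else -1
  override def read(dest: Array[Byte], off: Int, len: Int): Int = {
    if (!buf.hasRemaining) -1
    else {
      val n = math.min(len, buf.remaining())
      buf.get(dest, off, n)
      n
    }
  }
}

object BlockifyDemo {
  val BLOCK_SIZE = 4 // tiny block size so the demo yields several chunks

  // Split a byte array into BLOCK_SIZE-sized ByteBuffers, mirroring how
  // blockifyObject now wraps each chunk rather than keeping raw Array[Byte]s.
  def blockify(bytes: Array[Byte]): Array[ByteBuffer] =
    bytes.grouped(BLOCK_SIZE).map(chunk => ByteBuffer.wrap(chunk)).toArray

  // Reassemble by chaining one InputStream per buffer, mirroring
  // unBlockifyObject's SequenceInputStream-based read path.
  def unblockify(blocks: Array[ByteBuffer]): Array[Byte] = {
    val streams = blocks.iterator.map(b => new SimpleByteBufferInputStream(b): InputStream)
    val is = new SequenceInputStream(streams.asJavaEnumeration)
    Iterator.continually(is.read()).takeWhile(_ != -1).map(_.toByte).toArray
  }

  def main(args: Array[String]): Unit = {
    val original = "hello torrent broadcast".getBytes("UTF-8")
    val blocks = blockify(original)
    val roundTripped = unblockify(blocks)
    assert(java.util.Arrays.equals(original, roundTripped))
    println(s"round-tripped ${roundTripped.length} bytes through ${blocks.length} blocks")
  }
}

Wrapping each chunk with ByteBuffer.wrap reuses the underlying byte array instead of
copying it, and reading the pieces back through a SequenceInputStream of
ByteBuffer-backed streams avoids materializing one large concatenated array before
deserialization.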