Commit 2d6a5fb

Use putBytes/getRemoteBytes throughout.
rxin committed Aug 19, 2014
1 parent 3670f00 commit 2d6a5fb
Showing 1 changed file with 12 additions and 12 deletions.
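
In short, this commit stops round-tripping broadcast pieces through the block manager as serialized objects (putSingle/getSingle) and moves them as raw ByteBuffers instead (putBytes/getRemoteBytes), saving one serialize/deserialize per piece. A minimal sketch of the shift; the names bm, piece, and buf are hypothetical, and the signatures are inferred from the call sites in this diff, not copied from the full BlockManager API:

// Sketch only: bm, piece, and buf are hypothetical; signatures are
// inferred from the call sites in this diff.
val bm = SparkEnv.get.blockManager
val pieceId = BroadcastBlockId(id, "piece" + 0)

// Old path: the piece is a generic object the block manager must (de)serialize.
bm.putSingle(pieceId, piece, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)
val before: Option[Any] = bm.getSingle(pieceId)

// New path: the piece is already bytes, so store and fetch it as a ByteBuffer.
bm.putBytes(pieceId, buf, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)
val after: Option[ByteBuffer] = bm.getRemoteBytes(pieceId)

One behavioral nuance: getSingle also checked the local store, while getRemoteBytes goes straight to other block managers.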
@@ -27,6 +27,7 @@ import scala.util.Random
 import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
+import org.apache.spark.util.ByteBufferInputStream
 
 /**
  * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
@@ -76,8 +77,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](
   private def writeBlocks(): Int = {
     val blocks = TorrentBroadcast.blockifyObject(_value)
     blocks.zipWithIndex.foreach { case (block, i) =>
-      // TODO: Use putBytes directly.
-      SparkEnv.get.blockManager.putSingle(
+      SparkEnv.get.blockManager.putBytes(
         BroadcastBlockId(id, "piece" + i),
         block,
         StorageLevel.MEMORY_AND_DISK_SER,
@@ -87,21 +87,21 @@ private[spark] class TorrentBroadcast[T: ClassTag](
   }
 
   /** Fetch torrent blocks from the driver and/or other executors. */
-  private def readBlocks(): Array[Array[Byte]] = {
+  private def readBlocks(): Array[ByteBuffer] = {
     // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
     // to the driver, so other executors can pull these chunks from this executor as well.
     var numBlocksAvailable = 0
-    val blocks = new Array[Array[Byte]](numBlocks)
+    val blocks = new Array[ByteBuffer](numBlocks)
 
     for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
       val pieceId = BroadcastBlockId(id, "piece" + pid)
-      SparkEnv.get.blockManager.getSingle(pieceId) match {
+      SparkEnv.get.blockManager.getRemoteBytes(pieceId) match {
         case Some(x) =>
-          blocks(pid) = x.asInstanceOf[Array[Byte]]
+          blocks(pid) = x.asInstanceOf[ByteBuffer]
           numBlocksAvailable += 1
           SparkEnv.get.blockManager.putBytes(
             pieceId,
-            ByteBuffer.wrap(blocks(pid)),
+            blocks(pid),
             StorageLevel.MEMORY_AND_DISK_SER,
             tellMaster = true)

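The loop in this hunk is what makes the broadcast "BitTorrent-like": pieces are fetched in a random order, and every fetched piece is immediately re-stored with tellMaster = true so later readers can pull it from this executor instead of the driver. A Spark-free sketch of that dynamic, assuming a hypothetical registry in place of the BlockManager master and hard-coded host names:

import scala.collection.mutable
import scala.util.Random

object TorrentFetchDemo {
  // Hypothetical availability registry standing in for the BlockManager
  // master: piece id -> hosts currently able to serve that piece.
  val registry = mutable.Map.empty[Int, mutable.Set[String]]

  def register(pid: Int, host: String): Unit =
    registry.getOrElseUpdate(pid, mutable.Set.empty) += host

  // Fetch every piece in random order; after each fetch, advertise the
  // local copy, the analogue of putBytes(..., tellMaster = true) above.
  def fetchAll(host: String, numBlocks: Int): Unit =
    for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
      val servers = registry(pid).toVector
      val from = servers(Random.nextInt(servers.size))
      println(s"$host fetched piece $pid from $from")
      register(pid, host)
    }

  def main(args: Array[String]): Unit = {
    val numBlocks = 4
    (0 until numBlocks).foreach(register(_, "driver")) // driver seeds every piece
    fetchAll("executor-1", numBlocks)
    fetchAll("executor-2", numBlocks) // may now hit executor-1 as well as the driver
  }
}

The random shuffle is what spreads load: without it, every executor would hammer the driver for piece 0 first.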
@@ -182,7 +182,7 @@ private object TorrentBroadcast extends Logging {
     initialized = false
   }
 
-  def blockifyObject[T: ClassTag](obj: T): Array[Array[Byte]] = {
+  def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = {
     // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks
     // so we don't need to do the extra memory copy.
     val bos = new ByteArrayOutputStream()
@@ -193,24 +193,24 @@ private object TorrentBroadcast extends Logging {
     val byteArray = bos.toByteArray
     val bais = new ByteArrayInputStream(byteArray)
     val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt
-    val blocks = new Array[Array[Byte]](numBlocks)
+    val blocks = new Array[ByteBuffer](numBlocks)
 
     var blockId = 0
     for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
       val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
       val tempByteArray = new Array[Byte](thisBlockSize)
       bais.read(tempByteArray, 0, thisBlockSize)
 
-      blocks(blockId) = tempByteArray
+      blocks(blockId) = ByteBuffer.wrap(tempByteArray)
       blockId += 1
     }
     bais.close()
     blocks
   }
 
-  def unBlockifyObject[T: ClassTag](blocks: Array[Array[Byte]]): T = {
+  def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = {
     val is = new SequenceInputStream(
-      asJavaEnumeration(blocks.iterator.map(block => new ByteArrayInputStream(block))))
+      asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block))))
     val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is
 
     val ser = SparkEnv.get.serializer.newInstance()
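
Taken together, blockifyObject and unBlockifyObject now chunk a serialized object into ByteBuffer pieces and splice them back with a SequenceInputStream. A self-contained sketch of that round trip, under stated simplifications: BLOCK_SIZE here is a tiny stand-in for TorrentBroadcast.BLOCK_SIZE, serialization and compression are skipped, and reading each buffer's backing array approximates Spark's ByteBufferInputStream (safe here because every buffer wraps a fresh array):

import java.io.{ByteArrayInputStream, SequenceInputStream}
import java.nio.ByteBuffer
import scala.collection.JavaConverters._

object BlockifyDemo {
  // Stand-in for TorrentBroadcast.BLOCK_SIZE; tiny so the demo forces several chunks.
  val BLOCK_SIZE = 4

  // Split a byte array into ByteBuffer chunks, as blockifyObject now does
  // (minus serialization and compression).
  def blockify(bytes: Array[Byte]): Array[ByteBuffer] =
    bytes.grouped(BLOCK_SIZE).map(a => ByteBuffer.wrap(a)).toArray

  // Reassemble by streaming the chunks back-to-back, as unBlockifyObject does.
  // Reading each buffer's backing array stands in for ByteBufferInputStream.
  def unBlockify(blocks: Array[ByteBuffer]): Array[Byte] = {
    val is = new SequenceInputStream(
      blocks.iterator.map(b => new ByteArrayInputStream(b.array())).asJavaEnumeration)
    Iterator.continually(is.read()).takeWhile(_ != -1).map(_.toByte).toArray
  }

  def main(args: Array[String]): Unit = {
    val payload = "hello torrent broadcast".getBytes("UTF-8")
    assert(unBlockify(blockify(payload)).sameElements(payload))
    println("round trip ok: " + blockify(payload).length + " chunks")
  }
}

grouped handles the final short chunk automatically, matching the math.min logic in the diff above.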
