Skip to content

Commit

Permalink
Clean up @sarutak's PR apache#1490 for [SPARK-2583]: ConnectionManage…
Browse files Browse the repository at this point in the history
…r error reporting

Use Futures to signal failures, rather than exposing empty messages to user code.
  • Loading branch information
JoshRosen committed Aug 4, 2014
1 parent 7399c6b commit f1cd1bb
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.network

import java.io.IOException
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
Expand All @@ -41,16 +42,26 @@ import org.apache.spark.util.{SystemClock, Utils}
private[spark] class ConnectionManager(port: Int, conf: SparkConf,
securityManager: SecurityManager) extends Logging {

/**
* Used by sendMessageReliably to track messages being sent.
* @param message the message that was sent
* @param connectionManagerId the connection manager that sent this message
* @param completionHandler callback that's invoked when the send has completed or failed
*/
class MessageStatus(
val message: Message,
val connectionManagerId: ConnectionManagerId,
completionHandler: MessageStatus => Unit) {

/** This is non-None if message has been ack'd */
var ackMessage: Option[Message] = None
var attempted = false
var acked = false

def markDone() { completionHandler(this) }
def markDone(ackMessage: Option[Message]) {
this.synchronized {
this.ackMessage = ackMessage
completionHandler(this)
}
}
}

private val selector = SelectorProvider.provider.openSelector()
Expand Down Expand Up @@ -434,11 +445,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId)
.foreach(status => {
logInfo("Notifying " + status)
status.synchronized {
status.attempted = true
status.acked = false
status.markDone()
}
status.markDone(None)
})

messageStatuses.retain((i, status) => {
Expand Down Expand Up @@ -467,11 +474,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
for (s <- messageStatuses.values
if s.connectionManagerId == sendingConnectionManagerId) {
logInfo("Notifying " + s)
s.synchronized {
s.attempted = true
s.acked = false
s.markDone()
}
s.markDone(None)
}

messageStatuses.retain((i, status) => {
Expand Down Expand Up @@ -539,13 +542,13 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
securityMsg.getConnectionId.toString)
val message = securityMsgResp.toBufferMessage
if (message == null) throw new Exception("Error creating security message")
if (message == null) throw new IOException("Error creating security message")
sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
} catch {
case e: Exception => {
logError("Error handling sasl client authentication", e)
waitingConn.close()
throw new Exception("Error evaluating sasl response: " + e)
throw new IOException("Error evaluating sasl response: " + e)
}
}
}
Expand Down Expand Up @@ -653,12 +656,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
}
}
}
sentMessageStatus.synchronized {
sentMessageStatus.ackMessage = Some(message)
sentMessageStatus.attempted = true
sentMessageStatus.acked = true
sentMessageStatus.markDone()
}
sentMessageStatus.markDone(Some(message))
} else {
var ackMessage : Option[Message] = None
try {
Expand All @@ -681,7 +679,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
}
} catch {
case e: Exception => {
logError(s"Exception was thrown during processing message", e)
logError(s"Exception was thrown while processing message", e)
val m = Message.createBufferMessage(bufferMessage.id)
m.hasError = true
ackMessage = Some(m)
Expand Down Expand Up @@ -802,11 +800,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
case Some(msgStatus) => {
messageStatuses -= message.id
logInfo("Notifying " + msgStatus.connectionManagerId)
msgStatus.synchronized {
msgStatus.attempted = true
msgStatus.acked = false
msgStatus.markDone()
}
msgStatus.markDone(None)
}
case None => {
logError("no messageStatus for failed message id: " + message.id)
Expand All @@ -825,11 +819,28 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
selector.wakeup()
}

/**
* Send a message and block until an acknowledgment is received or an error occurs.
* @param connectionManagerId the message's destination
* @param message the message being sent
* @return a Future that either returns the acknowledgment message or captures an exception.
*/
def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message)
: Future[Option[Message]] = {
val promise = Promise[Option[Message]]
val status = new MessageStatus(
message, connectionManagerId, s => promise.success(s.ackMessage))
: Future[Message] = {
val promise = Promise[Message]()
val status = new MessageStatus(message, connectionManagerId, s => {
s.ackMessage match {
case None => // Indicates a failure where we either never sent or never got ACK'd
promise.failure(new IOException("sendMessageReliably failed without being ACK'd"))
case Some(ackMessage) =>
if (ackMessage.hasError) {
promise.failure(
new IOException("sendMessageReliably failed with ACK that signalled a remote error"))
} else {
promise.success(ackMessage)
}
}
})
messageStatuses.synchronized {
messageStatuses += ((message.id, status))
}
Expand All @@ -838,7 +849,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
}

def sendMessageReliablySync(connectionManagerId: ConnectionManagerId,
message: Message): Option[Message] = {
message: Message): Message = {
Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf)
}

Expand All @@ -864,6 +875,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,


private[spark] object ConnectionManager {
import ExecutionContext.Implicits.global

def main(args: Array[String]) {
val conf = new SparkConf
Expand Down Expand Up @@ -919,8 +931,10 @@ private[spark] object ConnectionManager {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {
val g = Await.result(f, 1 second)
if (!g.isDefined) println("Failed")
f.onFailure {
case e => println("Failed due to " + e)
}
Await.ready(f, 1 second)
})
val finishTime = System.currentTimeMillis

Expand Down Expand Up @@ -954,8 +968,10 @@ private[spark] object ConnectionManager {
val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {
val g = Await.result(f, 1 second)
if (!g.isDefined) println("Failed")
f.onFailure {
case e => println("Failed due to " + e)
}
Await.ready(f, 1 second)
})
val finishTime = System.currentTimeMillis

Expand Down Expand Up @@ -984,8 +1000,10 @@ private[spark] object ConnectionManager {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {
val g = Await.result(f, 1 second)
if (!g.isDefined) println("Failed")
f.onFailure {
case e => println("Failed due to " + e)
}
Await.ready(f, 1 second)
})
val finishTime = System.currentTimeMillis
Thread.sleep(1000)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.network

import java.nio.ByteBuffer
import scala.util.Try
import org.apache.spark.{SecurityManager, SparkConf}

private[spark] object SenderTest {
Expand Down Expand Up @@ -51,7 +52,7 @@ private[spark] object SenderTest {
val dataMessage = Message.createBufferMessage(buffer.duplicate)
val startTime = System.currentTimeMillis
/* println("Started timer at " + startTime) */
val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage)
val responseStr = Try(manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage))
.map { response =>
val buffer = response.asInstanceOf[BufferMessage].buffers(0)
new String(buffer.array, "utf-8")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashSet
import scala.collection.mutable.Queue
import scala.util.{Failure, Success}

import io.netty.buffer.ByteBuf

Expand Down Expand Up @@ -118,31 +119,24 @@ object BlockFetcherIterator {
bytesInFlight += req.size
val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
future.onSuccess {
case Some(message) => {
future.onComplete {
case Success(message) => {
val bufferMessage = message.asInstanceOf[BufferMessage]
if (bufferMessage.hasError) {
logError("Could not get block(s) from " + cmId)
for ((blockId, size) <- req.blocks) {
results.put(new FetchResult(blockId, -1, null))
}
} else {
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
for (blockMessage <- blockMessageArray) {
if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
throw new SparkException(
"Unexpected message " + blockMessage.getType + " received from " + cmId)
}
val blockId = blockMessage.getId
val networkSize = blockMessage.getData.limit()
results.put(new FetchResult(blockId, sizeMap(blockId),
() => dataDeserialize(blockId, blockMessage.getData, serializer)))
_remoteBytesRead += networkSize
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
for (blockMessage <- blockMessageArray) {
if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
throw new SparkException(
"Unexpected message " + blockMessage.getType + " received from " + cmId)
}
val blockId = blockMessage.getId
val networkSize = blockMessage.getData.limit()
results.put(new FetchResult(blockId, sizeMap(blockId),
() => dataDeserialize(blockId, blockMessage.getData, serializer)))
_remoteBytesRead += networkSize
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
}
case None => {
case Failure(exception) => {
logError("Could not get block(s) from " + cmId)
for ((blockId, size) <- req.blocks) {
results.put(new FetchResult(blockId, -1, null))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import org.apache.spark.Logging
import org.apache.spark.network._
import org.apache.spark.util.Utils

import scala.util.{Failure, Success, Try}

/**
* A network interface for BlockManager. Each slave should have one
* BlockManagerWorker.
Expand Down Expand Up @@ -115,28 +117,28 @@ private[spark] object BlockManagerWorker extends Logging {
val connectionManager = blockManager.connectionManager
val blockMessage = BlockMessage.fromPutBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
val resultMessage = connectionManager.sendMessageReliablySync(
toConnManagerId, blockMessageArray.toBufferMessage)
resultMessage.isDefined
val resultMessage = Try(connectionManager.sendMessageReliablySync(
toConnManagerId, blockMessageArray.toBufferMessage))
resultMessage.isSuccess
}

def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
val blockManager = blockManagerWorker.blockManager
val connectionManager = blockManager.connectionManager
val blockMessage = BlockMessage.fromGetBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
val responseMessage = connectionManager.sendMessageReliablySync(
toConnManagerId, blockMessageArray.toBufferMessage)
val responseMessage = Try(connectionManager.sendMessageReliablySync(
toConnManagerId, blockMessageArray.toBufferMessage))
responseMessage match {
case Some(message) => {
case Success(message) => {
val bufferMessage = message.asInstanceOf[BufferMessage]
logDebug("Response message received " + bufferMessage)
BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => {
logDebug("Found " + blockMessage)
return blockMessage.getData
})
}
case None => logDebug("No response message received")
case Failure(exception) => logDebug("No response message received")
}
null
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.scalatest.FunSuite
import scala.concurrent.{Await, TimeoutException}
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.Try

/**
* Test the ConnectionManager with various security settings.
Expand Down Expand Up @@ -209,7 +210,6 @@ class ConnectionManagerSuite extends FunSuite {
}).foreach(f => {
try {
val g = Await.result(f, 1 second)
if (!g.isDefined) assert(false) else assert(true)
} catch {
case e: Exception => {
assert(false)
Expand Down Expand Up @@ -240,9 +240,8 @@ class ConnectionManagerSuite extends FunSuite {

val future = manager.sendMessageReliably(managerServer.id, bufferMessage)

val message = Await.result(future, 1 second)
assert(message.isDefined)
assert(message.get.hasError)
val message = Try(Await.result(future, 1 second))
assert(message.isFailure)

manager.stop()
managerServer.stop()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
val f = future {
val message = Message.createBufferMessage(0)
message.hasError = true
val someMessage = Some(message)
someMessage
message
}
when(connManager.sendMessageReliably(any(),
any())).thenReturn(f)
Expand Down Expand Up @@ -204,10 +203,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
buffer.flip()
arrayBuffer += buffer

val someMessage = Some(Message.createBufferMessage(arrayBuffer))

val f = future {
someMessage
Message.createBufferMessage(arrayBuffer)
}
when(connManager.sendMessageReliably(any(),
any())).thenReturn(f)
Expand Down

0 comments on commit f1cd1bb

Please sign in to comment.