Skip to content

Commit

Permalink
[SPARK-1385] Use existing code for JSON de/serialization of BlockId
Browse files Browse the repository at this point in the history
`BlockId.scala` offers a way to reconstruct a BlockId from a string through regex matching. `util/JsonProtocol.scala` duplicates this functionality by explicitly matching on the BlockId type.
With this PR, the de/serialization of BlockIds will go through the first (older) code path.

(Most of the line changes in this PR involve changing `==` to `===` in `JsonProtocolSuite.scala`)

Author: Andrew Or <[email protected]>

Closes #289 from andrewor14/blockid-json and squashes the following commits:

409d226 [Andrew Or] Simplify JSON de/serialization for BlockId
  • Loading branch information
andrewor14 authored and aarondav committed Apr 2, 2014
1 parent 11973a7 commit de8eefa
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 146 deletions.
77 changes: 2 additions & 75 deletions core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ private[spark] object JsonProtocol {
taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing)
val updatedBlocks = taskMetrics.updatedBlocks.map { blocks =>
JArray(blocks.toList.map { case (id, status) =>
("Block ID" -> blockIdToJson(id)) ~
("Block ID" -> id.toString) ~
("Status" -> blockStatusToJson(status))
})
}.getOrElse(JNothing)
Expand Down Expand Up @@ -284,35 +284,6 @@ private[spark] object JsonProtocol {
("Replication" -> storageLevel.replication)
}

def blockIdToJson(blockId: BlockId): JValue = {
val blockType = Utils.getFormattedClassName(blockId)
val json: JObject = blockId match {
case rddBlockId: RDDBlockId =>
("RDD ID" -> rddBlockId.rddId) ~
("Split Index" -> rddBlockId.splitIndex)
case shuffleBlockId: ShuffleBlockId =>
("Shuffle ID" -> shuffleBlockId.shuffleId) ~
("Map ID" -> shuffleBlockId.mapId) ~
("Reduce ID" -> shuffleBlockId.reduceId)
case broadcastBlockId: BroadcastBlockId =>
"Broadcast ID" -> broadcastBlockId.broadcastId
case broadcastHelperBlockId: BroadcastHelperBlockId =>
("Broadcast Block ID" -> blockIdToJson(broadcastHelperBlockId.broadcastId)) ~
("Helper Type" -> broadcastHelperBlockId.hType)
case taskResultBlockId: TaskResultBlockId =>
"Task ID" -> taskResultBlockId.taskId
case streamBlockId: StreamBlockId =>
("Stream ID" -> streamBlockId.streamId) ~
("Unique ID" -> streamBlockId.uniqueId)
case tempBlockId: TempBlockId =>
val uuid = UUIDToJson(tempBlockId.id)
"Temp ID" -> uuid
case testBlockId: TestBlockId =>
"Test ID" -> testBlockId.id
}
("Type" -> blockType) ~ json
}

def blockStatusToJson(blockStatus: BlockStatus): JValue = {
val storageLevel = storageLevelToJson(blockStatus.storageLevel)
("Storage Level" -> storageLevel) ~
Expand Down Expand Up @@ -513,7 +484,7 @@ private[spark] object JsonProtocol {
Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson)
metrics.updatedBlocks = Utils.jsonOption(json \ "Updated Blocks").map { value =>
value.extract[List[JValue]].map { block =>
val id = blockIdFromJson(block \ "Block ID")
val id = BlockId((block \ "Block ID").extract[String])
val status = blockStatusFromJson(block \ "Status")
(id, status)
}
Expand Down Expand Up @@ -616,50 +587,6 @@ private[spark] object JsonProtocol {
StorageLevel(useDisk, useMemory, deserialized, replication)
}

def blockIdFromJson(json: JValue): BlockId = {
val rddBlockId = Utils.getFormattedClassName(RDDBlockId)
val shuffleBlockId = Utils.getFormattedClassName(ShuffleBlockId)
val broadcastBlockId = Utils.getFormattedClassName(BroadcastBlockId)
val broadcastHelperBlockId = Utils.getFormattedClassName(BroadcastHelperBlockId)
val taskResultBlockId = Utils.getFormattedClassName(TaskResultBlockId)
val streamBlockId = Utils.getFormattedClassName(StreamBlockId)
val tempBlockId = Utils.getFormattedClassName(TempBlockId)
val testBlockId = Utils.getFormattedClassName(TestBlockId)

(json \ "Type").extract[String] match {
case `rddBlockId` =>
val rddId = (json \ "RDD ID").extract[Int]
val splitIndex = (json \ "Split Index").extract[Int]
new RDDBlockId(rddId, splitIndex)
case `shuffleBlockId` =>
val shuffleId = (json \ "Shuffle ID").extract[Int]
val mapId = (json \ "Map ID").extract[Int]
val reduceId = (json \ "Reduce ID").extract[Int]
new ShuffleBlockId(shuffleId, mapId, reduceId)
case `broadcastBlockId` =>
val broadcastId = (json \ "Broadcast ID").extract[Long]
new BroadcastBlockId(broadcastId)
case `broadcastHelperBlockId` =>
val broadcastBlockId =
blockIdFromJson(json \ "Broadcast Block ID").asInstanceOf[BroadcastBlockId]
val hType = (json \ "Helper Type").extract[String]
new BroadcastHelperBlockId(broadcastBlockId, hType)
case `taskResultBlockId` =>
val taskId = (json \ "Task ID").extract[Long]
new TaskResultBlockId(taskId)
case `streamBlockId` =>
val streamId = (json \ "Stream ID").extract[Int]
val uniqueId = (json \ "Unique ID").extract[Long]
new StreamBlockId(streamId, uniqueId)
case `tempBlockId` =>
val tempId = UUIDFromJson(json \ "Temp ID")
new TempBlockId(tempId)
case `testBlockId` =>
val testId = (json \ "Test ID").extract[String]
new TestBlockId(testId)
}
}

def blockStatusFromJson(json: JValue): BlockStatus = {
val storageLevel = storageLevelFromJson(json \ "Storage Level")
val memorySize = (json \ "Memory Size").extract[Long]
Expand Down
141 changes: 70 additions & 71 deletions core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ class JsonProtocolSuite extends FunSuite {
testBlockId(BroadcastHelperBlockId(BroadcastBlockId(2L), "Spark"))
testBlockId(TaskResultBlockId(1L))
testBlockId(StreamBlockId(1, 2L))
testBlockId(TempBlockId(UUID.randomUUID()))
}


Expand Down Expand Up @@ -168,8 +167,8 @@ class JsonProtocolSuite extends FunSuite {
}

private def testBlockId(blockId: BlockId) {
val newBlockId = JsonProtocol.blockIdFromJson(JsonProtocol.blockIdToJson(blockId))
blockId == newBlockId
val newBlockId = BlockId(blockId.toString)
assert(blockId === newBlockId)
}


Expand All @@ -180,90 +179,90 @@ class JsonProtocolSuite extends FunSuite {
private def assertEquals(event1: SparkListenerEvent, event2: SparkListenerEvent) {
(event1, event2) match {
case (e1: SparkListenerStageSubmitted, e2: SparkListenerStageSubmitted) =>
assert(e1.properties == e2.properties)
assert(e1.properties === e2.properties)
assertEquals(e1.stageInfo, e2.stageInfo)
case (e1: SparkListenerStageCompleted, e2: SparkListenerStageCompleted) =>
assertEquals(e1.stageInfo, e2.stageInfo)
case (e1: SparkListenerTaskStart, e2: SparkListenerTaskStart) =>
assert(e1.stageId == e2.stageId)
assert(e1.stageId === e2.stageId)
assertEquals(e1.taskInfo, e2.taskInfo)
case (e1: SparkListenerTaskGettingResult, e2: SparkListenerTaskGettingResult) =>
assertEquals(e1.taskInfo, e2.taskInfo)
case (e1: SparkListenerTaskEnd, e2: SparkListenerTaskEnd) =>
assert(e1.stageId == e2.stageId)
assert(e1.taskType == e2.taskType)
assert(e1.stageId === e2.stageId)
assert(e1.taskType === e2.taskType)
assertEquals(e1.reason, e2.reason)
assertEquals(e1.taskInfo, e2.taskInfo)
assertEquals(e1.taskMetrics, e2.taskMetrics)
case (e1: SparkListenerJobStart, e2: SparkListenerJobStart) =>
assert(e1.jobId == e2.jobId)
assert(e1.properties == e2.properties)
assertSeqEquals(e1.stageIds, e2.stageIds, (i1: Int, i2: Int) => assert(i1 == i2))
assert(e1.jobId === e2.jobId)
assert(e1.properties === e2.properties)
assertSeqEquals(e1.stageIds, e2.stageIds, (i1: Int, i2: Int) => assert(i1 === i2))
case (e1: SparkListenerJobEnd, e2: SparkListenerJobEnd) =>
assert(e1.jobId == e2.jobId)
assert(e1.jobId === e2.jobId)
assertEquals(e1.jobResult, e2.jobResult)
case (e1: SparkListenerEnvironmentUpdate, e2: SparkListenerEnvironmentUpdate) =>
assertEquals(e1.environmentDetails, e2.environmentDetails)
case (e1: SparkListenerBlockManagerAdded, e2: SparkListenerBlockManagerAdded) =>
assert(e1.maxMem == e2.maxMem)
assert(e1.maxMem === e2.maxMem)
assertEquals(e1.blockManagerId, e2.blockManagerId)
case (e1: SparkListenerBlockManagerRemoved, e2: SparkListenerBlockManagerRemoved) =>
assertEquals(e1.blockManagerId, e2.blockManagerId)
case (e1: SparkListenerUnpersistRDD, e2: SparkListenerUnpersistRDD) =>
assert(e1.rddId == e2.rddId)
assert(e1.rddId === e2.rddId)
case (SparkListenerShutdown, SparkListenerShutdown) =>
case _ => fail("Events don't match in types!")
}
}

private def assertEquals(info1: StageInfo, info2: StageInfo) {
assert(info1.stageId == info2.stageId)
assert(info1.name == info2.name)
assert(info1.numTasks == info2.numTasks)
assert(info1.submissionTime == info2.submissionTime)
assert(info1.completionTime == info2.completionTime)
assert(info1.emittedTaskSizeWarning == info2.emittedTaskSizeWarning)
assert(info1.stageId === info2.stageId)
assert(info1.name === info2.name)
assert(info1.numTasks === info2.numTasks)
assert(info1.submissionTime === info2.submissionTime)
assert(info1.completionTime === info2.completionTime)
assert(info1.emittedTaskSizeWarning === info2.emittedTaskSizeWarning)
assertEquals(info1.rddInfo, info2.rddInfo)
}

private def assertEquals(info1: RDDInfo, info2: RDDInfo) {
assert(info1.id == info2.id)
assert(info1.name == info2.name)
assert(info1.numPartitions == info2.numPartitions)
assert(info1.numCachedPartitions == info2.numCachedPartitions)
assert(info1.memSize == info2.memSize)
assert(info1.diskSize == info2.diskSize)
assert(info1.id === info2.id)
assert(info1.name === info2.name)
assert(info1.numPartitions === info2.numPartitions)
assert(info1.numCachedPartitions === info2.numCachedPartitions)
assert(info1.memSize === info2.memSize)
assert(info1.diskSize === info2.diskSize)
assertEquals(info1.storageLevel, info2.storageLevel)
}

private def assertEquals(level1: StorageLevel, level2: StorageLevel) {
assert(level1.useDisk == level2.useDisk)
assert(level1.useMemory == level2.useMemory)
assert(level1.deserialized == level2.deserialized)
assert(level1.replication == level2.replication)
assert(level1.useDisk === level2.useDisk)
assert(level1.useMemory === level2.useMemory)
assert(level1.deserialized === level2.deserialized)
assert(level1.replication === level2.replication)
}

private def assertEquals(info1: TaskInfo, info2: TaskInfo) {
assert(info1.taskId == info2.taskId)
assert(info1.index == info2.index)
assert(info1.launchTime == info2.launchTime)
assert(info1.executorId == info2.executorId)
assert(info1.host == info2.host)
assert(info1.taskLocality == info2.taskLocality)
assert(info1.gettingResultTime == info2.gettingResultTime)
assert(info1.finishTime == info2.finishTime)
assert(info1.failed == info2.failed)
assert(info1.serializedSize == info2.serializedSize)
assert(info1.taskId === info2.taskId)
assert(info1.index === info2.index)
assert(info1.launchTime === info2.launchTime)
assert(info1.executorId === info2.executorId)
assert(info1.host === info2.host)
assert(info1.taskLocality === info2.taskLocality)
assert(info1.gettingResultTime === info2.gettingResultTime)
assert(info1.finishTime === info2.finishTime)
assert(info1.failed === info2.failed)
assert(info1.serializedSize === info2.serializedSize)
}

private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) {
assert(metrics1.hostname == metrics2.hostname)
assert(metrics1.executorDeserializeTime == metrics2.executorDeserializeTime)
assert(metrics1.resultSize == metrics2.resultSize)
assert(metrics1.jvmGCTime == metrics2.jvmGCTime)
assert(metrics1.resultSerializationTime == metrics2.resultSerializationTime)
assert(metrics1.memoryBytesSpilled == metrics2.memoryBytesSpilled)
assert(metrics1.diskBytesSpilled == metrics2.diskBytesSpilled)
assert(metrics1.hostname === metrics2.hostname)
assert(metrics1.executorDeserializeTime === metrics2.executorDeserializeTime)
assert(metrics1.resultSize === metrics2.resultSize)
assert(metrics1.jvmGCTime === metrics2.jvmGCTime)
assert(metrics1.resultSerializationTime === metrics2.resultSerializationTime)
assert(metrics1.memoryBytesSpilled === metrics2.memoryBytesSpilled)
assert(metrics1.diskBytesSpilled === metrics2.diskBytesSpilled)
assertOptionEquals(
metrics1.shuffleReadMetrics, metrics2.shuffleReadMetrics, assertShuffleReadEquals)
assertOptionEquals(
Expand All @@ -272,31 +271,31 @@ class JsonProtocolSuite extends FunSuite {
}

private def assertEquals(metrics1: ShuffleReadMetrics, metrics2: ShuffleReadMetrics) {
assert(metrics1.shuffleFinishTime == metrics2.shuffleFinishTime)
assert(metrics1.totalBlocksFetched == metrics2.totalBlocksFetched)
assert(metrics1.remoteBlocksFetched == metrics2.remoteBlocksFetched)
assert(metrics1.localBlocksFetched == metrics2.localBlocksFetched)
assert(metrics1.fetchWaitTime == metrics2.fetchWaitTime)
assert(metrics1.remoteBytesRead == metrics2.remoteBytesRead)
assert(metrics1.shuffleFinishTime === metrics2.shuffleFinishTime)
assert(metrics1.totalBlocksFetched === metrics2.totalBlocksFetched)
assert(metrics1.remoteBlocksFetched === metrics2.remoteBlocksFetched)
assert(metrics1.localBlocksFetched === metrics2.localBlocksFetched)
assert(metrics1.fetchWaitTime === metrics2.fetchWaitTime)
assert(metrics1.remoteBytesRead === metrics2.remoteBytesRead)
}

private def assertEquals(metrics1: ShuffleWriteMetrics, metrics2: ShuffleWriteMetrics) {
assert(metrics1.shuffleBytesWritten == metrics2.shuffleBytesWritten)
assert(metrics1.shuffleWriteTime == metrics2.shuffleWriteTime)
assert(metrics1.shuffleBytesWritten === metrics2.shuffleBytesWritten)
assert(metrics1.shuffleWriteTime === metrics2.shuffleWriteTime)
}

private def assertEquals(bm1: BlockManagerId, bm2: BlockManagerId) {
assert(bm1.executorId == bm2.executorId)
assert(bm1.host == bm2.host)
assert(bm1.port == bm2.port)
assert(bm1.nettyPort == bm2.nettyPort)
assert(bm1.executorId === bm2.executorId)
assert(bm1.host === bm2.host)
assert(bm1.port === bm2.port)
assert(bm1.nettyPort === bm2.nettyPort)
}

private def assertEquals(result1: JobResult, result2: JobResult) {
(result1, result2) match {
case (JobSucceeded, JobSucceeded) =>
case (r1: JobFailed, r2: JobFailed) =>
assert(r1.failedStageId == r2.failedStageId)
assert(r1.failedStageId === r2.failedStageId)
assertEquals(r1.exception, r2.exception)
case _ => fail("Job results don't match in types!")
}
Expand All @@ -307,13 +306,13 @@ class JsonProtocolSuite extends FunSuite {
case (Success, Success) =>
case (Resubmitted, Resubmitted) =>
case (r1: FetchFailed, r2: FetchFailed) =>
assert(r1.shuffleId == r2.shuffleId)
assert(r1.mapId == r2.mapId)
assert(r1.reduceId == r2.reduceId)
assert(r1.shuffleId === r2.shuffleId)
assert(r1.mapId === r2.mapId)
assert(r1.reduceId === r2.reduceId)
assertEquals(r1.bmAddress, r2.bmAddress)
case (r1: ExceptionFailure, r2: ExceptionFailure) =>
assert(r1.className == r2.className)
assert(r1.description == r2.description)
assert(r1.className === r2.className)
assert(r1.description === r2.description)
assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals)
assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals)
case (TaskResultLost, TaskResultLost) =>
Expand All @@ -329,13 +328,13 @@ class JsonProtocolSuite extends FunSuite {
details2: Map[String, Seq[(String, String)]]) {
details1.zip(details2).foreach {
case ((key1, values1: Seq[(String, String)]), (key2, values2: Seq[(String, String)])) =>
assert(key1 == key2)
values1.zip(values2).foreach { case (v1, v2) => assert(v1 == v2) }
assert(key1 === key2)
values1.zip(values2).foreach { case (v1, v2) => assert(v1 === v2) }
}
}

private def assertEquals(exception1: Exception, exception2: Exception) {
assert(exception1.getMessage == exception2.getMessage)
assert(exception1.getMessage === exception2.getMessage)
assertSeqEquals(
exception1.getStackTrace,
exception2.getStackTrace,
Expand All @@ -344,11 +343,11 @@ class JsonProtocolSuite extends FunSuite {

private def assertJsonStringEquals(json1: String, json2: String) {
val formatJsonString = (json: String) => json.replaceAll("[\\s|]", "")
formatJsonString(json1) == formatJsonString(json2)
formatJsonString(json1) === formatJsonString(json2)
}

private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit) {
assert(seq1.length == seq2.length)
assert(seq1.length === seq2.length)
seq1.zip(seq2).foreach { case (t1, t2) =>
assertEquals(t1, t2)
}
Expand Down Expand Up @@ -389,11 +388,11 @@ class JsonProtocolSuite extends FunSuite {
}

private def assertBlockEquals(b1: (BlockId, BlockStatus), b2: (BlockId, BlockStatus)) {
assert(b1 == b2)
assert(b1 === b2)
}

private def assertStackTraceElementEquals(ste1: StackTraceElement, ste2: StackTraceElement) {
assert(ste1 == ste2)
assert(ste1 === ste2)
}


Expand Down

0 comments on commit de8eefa

Please sign in to comment.