[CARMEL-6185] Expose row count for RepeatableIterator (#1074)
* [CARMEL-6185] Expose row count for RepeatableIterator

* fix code style

* Fix code style

* Fix UT

* Update code

* Update code
wakun authored and GitHub Enterprise committed Oct 14, 2022
1 parent e0ff530 commit ee77923
Showing 12 changed files with 87 additions and 39 deletions.
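At a high level, the change threads a per-task output row count from the executor's TaskMetrics through a new JobListener overload into the RepeatableIterator handed back to the caller, so the total row count is known without re-reading the result. Below is a minimal, self-contained sketch of that flow; the *Like classes are simplified stand-ins introduced here for illustration, not the real Spark types.

// Hedged sketch: simplified stand-ins for TaskMetrics, JobListener and
// IterableJobWaiter, showing only the row-count plumbing added by this commit.
object RowCountFlowSketch {

  // Each task reports how many rows it produced.
  final class TaskMetricsLike { var recordsOutput: Long = 0L }

  // The listener gains an overload that also receives the task's metrics;
  // the default body keeps old two-argument implementations compiling.
  trait JobListenerLike {
    def taskSucceeded(index: Int, result: Any): Unit
    def taskSucceeded(index: Int, result: Any, metrics: TaskMetricsLike): Unit =
      taskSucceeded(index, result)
  }

  // The waiter sums the per-task counts so the iterator it later builds can
  // expose the total row count up front.
  final class JobWaiterLike extends JobListenerLike {
    private var allRowCount = 0L
    def rowCount: Long = allRowCount
    override def taskSucceeded(index: Int, result: Any): Unit =
      taskSucceeded(index, result, new TaskMetricsLike)
    override def taskSucceeded(index: Int, result: Any, metrics: TaskMetricsLike): Unit =
      allRowCount += metrics.recordsOutput
  }

  def main(args: Array[String]): Unit = {
    val waiter = new JobWaiterLike
    val m0 = new TaskMetricsLike; m0.recordsOutput = 100L
    val m1 = new TaskMetricsLike; m1.recordsOutput = 250L
    waiter.taskSucceeded(0, "partition-0", m0)
    waiter.taskSucceeded(1, "partition-1", m1)
    println(waiter.rowCount) // 350
  }
}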
@@ -43,6 +43,7 @@ private[spark] object InternalAccumulator {
val UPDATED_BLOCK_STATUSES = METRICS_PREFIX + "updatedBlockStatuses"
val PRUNED_STATS = "index.prunedStats"
val TEST_ACCUM = METRICS_PREFIX + "testAccumulator"
val RECORDS_OUTPUT = OUTPUT_METRICS_PREFIX + "recordsOutput"

// scalastyle:off

@@ -57,6 +57,7 @@ class TaskMetrics private[spark] () extends Serializable {
private val _peakExecutionMemory = new LongAccumulator
private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)]
private val _prunedStats = new PrunedMetricsAccum
private val _recordsOutput = new LongAccumulator

def prunedStats: PrunedMetricsAccum = _prunedStats
/**
@@ -113,6 +114,11 @@ class TaskMetrics private[spark] () extends Serializable {
*/
def peakExecutionMemory: Long = _peakExecutionMemory.sum

/**
* Total number of records output.
*/
def recordsOutput: Long = _recordsOutput.sum

/**
* Storage statuses of any blocks that have been updated as a result of this task.
*
@@ -152,6 +158,7 @@ class TaskMetrics private[spark] () extends Serializable {
private[spark] def setPrunedStats(v: List[PrunedStats]): Unit = {
_prunedStats.setValue(v)
}
private[spark] def setRecordsOutput(v: Long): Unit = _recordsOutput.setValue(v)

/**
* Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted
@@ -226,6 +233,7 @@ class TaskMetrics private[spark] () extends Serializable {
PEAK_EXECUTION_MEMORY -> _peakExecutionMemory,
UPDATED_BLOCK_STATUSES -> _updatedBlockStatuses,
PRUNED_STATS -> _prunedStats,
RECORDS_OUTPUT -> _recordsOutput,
shuffleRead.REMOTE_BLOCKS_FETCHED -> shuffleReadMetrics._remoteBlocksFetched,
shuffleRead.LOCAL_BLOCKS_FETCHED -> shuffleReadMetrics._localBlocksFetched,
shuffleRead.REMOTE_BYTES_READ -> shuffleReadMetrics._remoteBytesRead,
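The new counter follows the existing internal-metric pattern: a LongAccumulator that the executor sets once per task via the private[spark] setRecordsOutput and that the driver reads back through recordsOutput. As a rough illustration of consuming it driver-side, a hedged sketch (assumes this fork's spark-core on the classpath; recordsOutput does not exist in upstream Spark):

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Hedged sketch: aggregate the new per-task recordsOutput metric on the driver.
class RecordsOutputListener extends SparkListener {
  @volatile private var total = 0L

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    // taskMetrics can be null when a task fails before reporting its metrics.
    Option(taskEnd.taskMetrics).foreach(m => total += m.recordsOutput)
  }

  def totalRecordsOutput: Long = total
}

// Registration (assumed usage): sparkContext.addSparkListener(new RecordsOutputListener)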
8 changes: 4 additions & 4 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -324,12 +324,12 @@ abstract class RDD[T: ClassTag](
* expansion rate = (number of output rows in the task) / (number of input rows in task).
*/
final def expansionLimitedIterator(split: Partition, context: TaskContext): Iterator[T] = {
val innerItrator = iterator(split, context)
val innerIterator = iterator(split, context)
if (maxExpandRate > 0) {
new Iterator[T] {
private var output = 0
override def hasNext: Boolean = {
innerItrator.hasNext
innerIterator.hasNext
}
override def next(): T = {
output += 1
@@ -345,11 +345,11 @@
}
}
output = 0
}
innerItrator.next()
innerIterator.next()
}
}
} else {
innerItrator
innerIterator
}
}

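The RDD.scala hunk above only corrects the innerItrator typo; the surrounding wrapper is a count-and-check iterator that expansionLimitedIterator uses to bound the output/input expansion rate. A standalone sketch of that wrapper pattern, with a hypothetical flat maxOutput limit standing in for the expansion-rate check that is collapsed out of the hunk:

// Hedged sketch of a counting wrapper; maxOutput is a hypothetical stand-in
// for the expansion-rate limit enforced by the real expansionLimitedIterator.
object BoundedIteratorSketch {
  def boundedIterator[T](inner: Iterator[T], maxOutput: Long): Iterator[T] =
    new Iterator[T] {
      private var output = 0L
      override def hasNext: Boolean = inner.hasNext
      override def next(): T = {
        output += 1
        if (output > maxOutput) {
          throw new IllegalStateException(
            s"produced more than $maxOutput rows, aborting to bound expansion")
        }
        inner.next()
      }
    }

  def main(args: Array[String]): Unit =
    println(boundedIterator(Iterator.range(0, 3), maxOutput = 5L).toList) // List(0, 1, 2)
}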
@@ -1942,7 +1942,7 @@ private[spark] class DAGScheduler(
// taskSucceeded runs some user code that might throw an exception. Make sure
// we are resilient against that.
try {
job.listener.taskSucceeded(rt.outputId, event.result)
job.listener.taskSucceeded(rt.outputId, event.result, event.taskMetrics)
} catch {
case e: Throwable if !Utils.isFatalError(e) =>
// TODO: Perhaps we want to mark the resultStage as failed?
@@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.{Future, Promise}
import scala.reflect.ClassTag

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.Logging

/**
@@ -48,6 +49,7 @@ private[spark] class IterableJobWaiter[U: ClassTag, R](
// to hold the result data as spilled files or in memory array
private var spilledResultData: Option[Array[SpilledPartitionResult]] = None
private val resultData: Array[U] = new Array[U](totalTasks)
private var allRowCount: Long = 0L

// indicate whether the in memory result data has been cleaned after
// the result data is spilled to disk
@@ -66,7 +68,11 @@ private[spark] class IterableJobWaiter[U: ClassTag, R](
dagScheduler.cancelJob(jobId, None)
}

override def taskSucceeded(index: Int, result: Any): Unit = {
override def taskSucceeded(index: Int, result: Any): Unit =
taskSucceeded(index, result, new TaskMetrics)

override def taskSucceeded(index: Int, result: Any, taskMetrics: TaskMetrics): Unit = {
allRowCount += taskMetrics.recordsOutput
result match {
case spilledPartitionResult: Array[SpilledPartitionResult] =>
spilledResultData = Some(spilledPartitionResult)
@@ -109,12 +115,12 @@ private[spark] class IterableJobWaiter[U: ClassTag, R](
if (spilledResultData.nonEmpty) {
logInfo(s"Return result as a SpilledResultIterator for job $jobId " +
s"with files ${spilledResultData.get.map(_.file.getPath).mkString(",")}")
SpilledResultIterator[U, R](spilledResultData.get, resultConverter,
SpilledResultIterator[U, R](spilledResultData.get, resultConverter, allRowCount,
dagScheduler.sc.conf.getBoolean("spark.sql.thriftserver.cleanShareResultFiles", false),
dagScheduler.sc.conf.getBoolean("spark.sql.thriftserver.shareResult", true))
} else {
logInfo(s"Return result as a SimpleRepeatableIterator for job $jobId")
SimpleRepeatableIterator(resultData, resultConverter)
SimpleRepeatableIterator(resultData, resultConverter, allRowCount)
}
}

@@ -17,12 +17,17 @@

package org.apache.spark.scheduler

import org.apache.spark.executor.TaskMetrics

/**
* Interface used to listen for job completion or failure events after submitting a job to the
* DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole
* job fails (and no further taskSucceeded events will happen).
*/
private[spark] trait JobListener {
def taskSucceeded(index: Int, result: Any): Unit

def taskSucceeded(index: Int, result: Any, taskMetrics: TaskMetrics): Unit =
taskSucceeded(index, result)
def jobFailed(exception: Exception): Unit
}
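Because JobListener is private[spark], only scheduler-internal classes implement it, and the new overload's default body means existing two-argument implementations keep compiling unchanged. A hedged sketch (it would have to live in an org.apache.spark.scheduler source file, since the trait is private to Spark):

// Hedged sketch: a pre-existing listener that only implements the two-argument
// callback still satisfies the trait; calls that pass TaskMetrics fall through
// the default overload to it and simply drop the metrics.
class CountingListener extends JobListener {
  @volatile var succeeded = 0
  override def taskSucceeded(index: Int, result: Any): Unit = succeeded += 1
  override def jobFailed(exception: Exception): Unit = ()
}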
@@ -180,22 +180,28 @@ private[spark] case class SpilledPartitionResult(
/**
* the interface for iterator which supports read from the start again
*/
private[spark] trait RepeatableIterator[T] extends Iterator[T] {
private[spark] abstract class RepeatableIterator[T](_rowCount: Long) extends Iterator[T] {

def backToStart(): Unit

def close(): Unit

def copy(): RepeatableIterator[T]

def rowCount(): Long = _rowCount

// Do not call this method: length truncates to Int and goes negative when rowCount > Int.MaxValue
override def length: Int = rowCount.toInt
}

/**
* the simple iterator implementation to read from the in memory data
*/
private[spark] case class SimpleRepeatableIterator[T, U](
originData: Array[U],
resultConverter: U => Iterator[T])
extends RepeatableIterator[T] {
resultConverter: U => Iterator[T],
_rowCount: Long)
extends RepeatableIterator[T](_rowCount) {

private var it: Iterator[T] = originData.iterator.flatMap(resultConverter)

@@ -210,24 +216,29 @@ private[spark] case class SimpleRepeatableIterator[T, U](
override def close(): Unit = {
}

// length calculation is time consuming
override def length: Int = {
originData.iterator.flatMap(resultConverter).length
}

override def copy(): RepeatableIterator[T] = {
new SimpleRepeatableIterator[T, U](originData, resultConverter)
new SimpleRepeatableIterator[T, U](originData, resultConverter, rowCount)
}
}

/**
* The iterator implementation to read from spilled files
* data of spilledResults: Array[SpilledPartitionResult]
* file blockId offset length
* /data/yarn/tmp/file1, "temp_local_001", 0, 100
* /data/yarn/tmp/file1, "temp_local_002", 100, 200
* /data/yarn/tmp/file2, "temp_local_003", 0, 400
*
* nextBatchStream() will clean the temp file and then read a new SpilledPartitionResult.
* readNextBatch() will convert the partition result to currentBatch: Iterator[R]
*/
private[spark] case class SpilledResultIterator[U, R](
spilledResults: Array[SpilledPartitionResult],
converter: U => Iterator[R],
_rowCount: Long,
cleanShareResultFiles: Boolean = false,
override val isTraversableAgain: Boolean) extends RepeatableIterator[R] with Logging {
override val isTraversableAgain: Boolean)
extends RepeatableIterator[R](_rowCount) with Logging {

private val serializer = SparkEnv.get.serializer.newInstance()
private val serializerManager = SparkEnv.get.serializerManager
@@ -378,7 +389,7 @@ private[spark] case class SpilledResultIterator[U, R](
}

override def copy(): RepeatableIterator[R] = {
new SpilledResultIterator[U, R](spilledResults, converter, cleanShareResultFiles,
new SpilledResultIterator[U, R](spilledResults, converter, rowCount, cleanShareResultFiles,
isTraversableAgain)
}
}
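With the count carried in the constructor, callers can report result cardinality without exhausting the iterator, while the overridden length remains a trap because it truncates to Int. A hedged usage sketch (these classes are private[spark], so this would sit inside an org.apache.spark package; the sample data is made up):

object RowCountUsageSketch {
  def demo(): Unit = {
    val it = SimpleRepeatableIterator[String, Array[String]](
      Array(Array("a", "b"), Array("c")),        // originData: two in-memory partitions
      (chunk: Array[String]) => chunk.iterator,  // resultConverter
      3L)                                        // _rowCount, summed on the driver

    assert(it.rowCount() == 3L) // O(1); prefer this over length, which truncates to Int
    val first = it.toList       // consume once ...
    it.backToStart()            // ... rewind ...
    assert(it.toList == first)  // ... and read the same rows again
  }
}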
35 changes: 21 additions & 14 deletions core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
(The renumbering below is mechanical: inserting RECORDS_OUTPUT into TaskMetrics' internal accumulator list gives the new metric ID 11 and shifts every subsequent internal accumulator ID up by one, so the expected JSON in the test is updated accordingly.)
@@ -2093,104 +2093,111 @@ private[spark] object JsonProtocolSuite extends Assertions {
| },
| {
| "ID": 11,
| "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}",
| "Name": "${RECORDS_OUTPUT}",
| "Update": 0,
| "Internal": true,
| "Count Failed Values": true
| },
| {
| "ID": 12,
| "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}",
| "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}",
| "Update": 0,
| "Internal": true,
| "Count Failed Values": true
| },
| {
| "ID": 13,
| "Name": "${shuffleRead.REMOTE_BYTES_READ}",
| "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}",
| "Update": 0,
| "Internal": true,
| "Count Failed Values": true
| },
| {
| "ID": 14,
| "Name": "${shuffleRead.REMOTE_BYTES_READ_TO_DISK}",
| "Name": "${shuffleRead.REMOTE_BYTES_READ}",
| "Update": 0,
| "Internal": true,
| "Count Failed Values": true
| },
| {
| "ID": 15,
| "Name": "${shuffleRead.LOCAL_BYTES_READ}",
| "Name": "${shuffleRead.REMOTE_BYTES_READ_TO_DISK}",
| "Update": 0,
| "Internal": true,
| "Count Failed Values": true
| },
| {
| "ID": 16,
| "Name": "${shuffleRead.FETCH_WAIT_TIME}",
| "Name": "${shuffleRead.LOCAL_BYTES_READ}",
| "Update": 0,
| "Internal": true,
| "Count Failed Values": true
| },
| {
| "ID": 17,
| "Name": "${shuffleRead.RECORDS_READ}",
| "Name": "${shuffleRead.FETCH_WAIT_TIME}",
| "Update": 0,
| "Internal": true,
| "Count Failed Values": true
| },
| {
| "ID": 18,
| "Name": "${shuffleWrite.BYTES_WRITTEN}",
| "Name": "${shuffleRead.RECORDS_READ}",
| "Update": 0,
| "Internal": true,
| "Count Failed Values": true
| },
| {
| "ID": 19,
| "Name": "${shuffleWrite.RECORDS_WRITTEN}",
| "Name": "${shuffleWrite.BYTES_WRITTEN}",
| "Update": 0,
| "Internal": true,
| "Count Failed Values": true
| },
| {
| "ID": 20,
| "Name": "${shuffleWrite.WRITE_TIME}",
| "Name": "${shuffleWrite.RECORDS_WRITTEN}",
| "Update": 0,
| "Internal": true,
| "Count Failed Values": true
| },
| {
| "ID": 21,
| "Name": "${shuffleWrite.WRITE_TIME}",
| "Update": 0,
| "Internal": true,
| "Count Failed Values": true
| },
| {
| "ID": 22,
| "Name": "${input.BYTES_READ}",
| "Update": 2100,
| "Internal": true,
| "Count Failed Values": true
| },
| {
| "ID": 22,
| "ID": 23,
| "Name": "${input.RECORDS_READ}",
| "Update": 21,
| "Internal": true,
| "Count Failed Values": true
| },
| {
| "ID": 23,
| "ID": 24,
| "Name": "${output.BYTES_WRITTEN}",
| "Update": 1200,
| "Internal": true,
| "Count Failed Values": true
| },
| {
| "ID": 24,
| "ID": 25,
| "Name": "${output.RECORDS_WRITTEN}",
| "Update": 12,
| "Internal": true,
| "Count Failed Values": true
| },
| {
| "ID": 25,
| "ID": 26,
| "Name": "$TEST_ACCUM",
| "Update": 0,
| "Internal": true,
@@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.AbstractIterator
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

import org.apache.spark.{broadcast, SparkEnv, TaskKilledException}
import org.apache.spark.{broadcast, SparkEnv, TaskContext, TaskKilledException}
import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.rdd.{RDD, RDDOperationScope}
@@ -362,6 +362,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
out.writeInt(-1)
out.flush()
out.close()
TaskContext.get().taskMetrics().setRecordsOutput(count)
Iterator((count, bos.toByteArray))
}
}
@@ -544,7 +545,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
throw new IllegalArgumentException(s"Limit cannot exceed threshold ${conf.limitMaxRows}")
}
logInfo(s"Return limit result as a SimpleRepeatableIterator.")
SimpleRepeatableIterator[R, InternalRow](executeTake(n), row => Iterator(proj(row)))
val array = executeTake(n)
SimpleRepeatableIterator[R, InternalRow](array, row => Iterator(proj(row)), array.length)
}

/**
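On the executor side the plan already counts rows while serializing each partition; the only new line publishes that count through the task's metrics so the driver-side waiter can sum it. A hedged sketch of the pattern (setRecordsOutput is private[spark], so like SparkPlan this code would sit inside an org.apache.spark package; the stream and writeRow callback are illustrative stand-ins):

import java.io.DataOutputStream

import org.apache.spark.TaskContext

// Hedged sketch: count rows while writing a partition, then record the count
// in the task's metrics before returning, mirroring the SparkPlan change above.
object OutputCountSketch {
  def writeAndCount[T](rows: Iterator[T], out: DataOutputStream)(writeRow: T => Unit): Long = {
    var count = 0L
    while (rows.hasNext) {
      writeRow(rows.next())
      count += 1
    }
    out.writeInt(-1) // end-of-stream marker, as in the hunk above
    out.flush()
    TaskContext.get().taskMetrics().setRecordsOutput(count) // the new call
    count
  }
}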
(Diffs for the remaining changed files are not rendered in this view.)
