added console sink for continuous processing
tdas committed Jan 18, 2018
1 parent 1c76a91 commit 6f69669
Showing 3 changed files with 96 additions and 30 deletions.
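For context, the user-visible effect of this commit is that the console sink can run under a continuous trigger, not only in micro-batch mode. The following is a minimal usage sketch, not part of the commit itself: it assumes a local SparkSession on a Spark build from around this change (the 2.3 development line, where these DataSourceV2 streaming APIs existed under these names) and uses the built-in rate source with Trigger.Continuous, mirroring the new test added below.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("console-continuous-demo")
  .getOrCreate()
import spark.implicits._

val input = spark.readStream
  .format("rate")                       // built-in source emitting (timestamp, value) rows
  .option("rowsPerSecond", "5")
  .load()
  .select('value)

// Before this commit the console sink only supported micro-batch execution;
// with ContinuousWriteSupport it can also be started with Trigger.Continuous.
val query = input.writeStream
  .format("console")
  .option("numRows", "10")              // read by the shared ConsoleWriter trait
  .option("truncate", "false")
  .trigger(Trigger.Continuous("1 second"))
  .start()

query.awaitTermination(10 * 1000)       // let a few epochs print to the console
query.stop()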
@@ -19,13 +19,12 @@ package org.apache.spark.sql.execution.streaming

import java.util.Optional

import scala.collection.JavaConverters._

import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter
import org.apache.spark.sql.execution.streaming.sources.{ConsoleContinuousWriter, ConsoleMicroBatchWriter, ConsoleWriter}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister}
import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options}
import org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport
import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport}
import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
@@ -37,16 +36,25 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame)

class ConsoleSinkProvider extends DataSourceV2
with MicroBatchWriteSupport
with ContinuousWriteSupport
with DataSourceRegister
with CreatableRelationProvider {

override def createMicroBatchWriter(
queryId: String,
epochId: Long,
batchId: Long,
schema: StructType,
mode: OutputMode,
options: DataSourceV2Options): Optional[DataSourceV2Writer] = {
Optional.of(new ConsoleWriter(epochId, schema, options))
Optional.of(new ConsoleMicroBatchWriter(batchId, schema, options))
}

override def createContinuousWriter(
queryId: String,
schema: StructType,
mode: OutputMode,
options: DataSourceV2Options): Optional[ContinuousWriter] = {
Optional.of(new ConsoleContinuousWriter(schema, options))
}

def createRelation(
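For readers less familiar with this pre-2.3 DataSourceV2 streaming API, the two factory methods above are the entire surface a sink needs in order to serve both execution modes. A hypothetical, illustration-only provider following the same pattern might look like the sketch below; NoopSinkProvider and the "noop" short name are invented for this sketch, both writers simply discard their commit messages, and the method signatures are copied from the diff above.

package org.apache.spark.sql.execution.streaming.sources

import java.util.Optional

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options}
import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport}
import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

/** Hypothetical sink that accepts rows in either mode and throws them away. */
class NoopSinkProvider extends DataSourceV2
  with MicroBatchWriteSupport
  with ContinuousWriteSupport
  with DataSourceRegister {

  override def shortName(): String = "noop"

  override def createMicroBatchWriter(
      queryId: String,
      batchId: Long,
      schema: StructType,
      mode: OutputMode,
      options: DataSourceV2Options): Optional[DataSourceV2Writer] = {
    Optional.of(new DataSourceV2Writer {
      override def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory
      override def commit(messages: Array[WriterCommitMessage]): Unit = {}
      override def abort(messages: Array[WriterCommitMessage]): Unit = {}
    })
  }

  override def createContinuousWriter(
      queryId: String,
      schema: StructType,
      mode: OutputMode,
      options: DataSourceV2Options): Optional[ContinuousWriter] = {
    Optional.of(new ContinuousWriter {
      override def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory
      override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
      override def abort(messages: Array[WriterCommitMessage]): Unit = {}
    })
  }
}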
@@ -20,45 +20,85 @@ package org.apache.spark.sql.execution.streaming.sources
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.sources.v2.DataSourceV2Options
import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.types.StructType

/**
* A [[DataSourceV2Writer]] that collects results to the driver and prints them in the console.
* Generated by [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]].
*
* This sink should not be used for production, as it requires sending all rows to the driver
* and does not support recovery.
*/
class ConsoleWriter(batchId: Long, schema: StructType, options: DataSourceV2Options)
extends DataSourceV2Writer with Logging {
/** Common methods used to create writers for the console sink */
trait ConsoleWriter extends Logging {

def options: DataSourceV2Options

// Number of rows to display, by default 20 rows
private val numRowsToShow = options.getInt("numRows", 20)
protected val numRowsToShow = options.getInt("numRows", 20)

// Truncate the displayed data if it is too long, by default it is true
private val isTruncated = options.getBoolean("truncate", true)
protected val isTruncated = options.getBoolean("truncate", true)

assert(SparkSession.getActiveSession.isDefined)
private val spark = SparkSession.getActiveSession.get
protected val spark = SparkSession.getActiveSession.get

def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory

override def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory
def abort(messages: Array[WriterCommitMessage]): Unit = {}

override def commit(messages: Array[WriterCommitMessage]): Unit = synchronized {
val batch = messages.collect {
protected def printRows(
commitMessages: Array[WriterCommitMessage],
schema: StructType,
printMessage: String): Unit = {
val rows = commitMessages.collect {
case PackedRowCommitMessage(rows) => rows
}.flatten

// scalastyle:off println
println("-------------------------------------------")
println(s"Batch: $batchId")
println(printMessage)
println("-------------------------------------------")
// scalastyle:on println
spark.createDataFrame(
spark.sparkContext.parallelize(batch), schema)
spark
.createDataFrame(spark.sparkContext.parallelize(rows), schema)
.show(numRowsToShow, isTruncated)
}
}


/**
* A [[DataSourceV2Writer]] that collects results from a micro-batch query to the driver and
* prints them in the console. Created by
* [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]].
*
* This sink should not be used for production, as it requires sending all rows to the driver
* and does not support recovery.
*/
class ConsoleMicroBatchWriter(batchId: Long, schema: StructType, val options: DataSourceV2Options)
extends DataSourceV2Writer with ConsoleWriter {

override def commit(messages: Array[WriterCommitMessage]): Unit = {
printRows(messages, schema, s"Batch: $batchId")
}

override def toString(): String = {
s"ConsoleMicroBatchWriter[numRows=$numRowsToShow, truncate=$isTruncated]"
}
}

override def abort(messages: Array[WriterCommitMessage]): Unit = {}

override def toString(): String = s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]"
/**
* A [[DataSourceV2Writer]] that collects results from a continuous query to the driver and
* prints them in the console. Created by
* [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]].
*
* This sink should not be used for production, as it requires sending all rows to the driver
* and does not support recovery.
*/
class ConsoleContinuousWriter(schema: StructType, val options: DataSourceV2Options)
extends ContinuousWriter with ConsoleWriter {

override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
printRows(messages, schema, s"Continuous processing epoch $epochId")
}

override def toString(): String = {
s"ConsoleContinuousWriter[numRows=$numRowsToShow, truncate=$isTruncated]"
}
}
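One consequence of this refactoring worth spelling out: abort, createWriterFactory, and printRows now live in the shared ConsoleWriter trait, so a new writer variant only has to supply the options member and choose what header to print. The sketch below is illustration only; ConsoleOnceWriter and its "One-off write" header are invented names, not part of this commit.

package org.apache.spark.sql.execution.streaming.sources

import org.apache.spark.sql.sources.v2.DataSourceV2Options
import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, WriterCommitMessage}
import org.apache.spark.sql.types.StructType

/** Hypothetical writer that reuses the ConsoleWriter trait for a one-off write. */
class ConsoleOnceWriter(schema: StructType, val options: DataSourceV2Options)
  extends DataSourceV2Writer with ConsoleWriter {

  // numRowsToShow, isTruncated, createWriterFactory and abort all come from the
  // trait; only the commit hook and the printed header are specific to this writer.
  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    printRows(messages, schema, "One-off write")
  }

  override def toString(): String =
    s"ConsoleOnceWriter[numRows=$numRowsToShow, truncate=$isTruncated]"
}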
@@ -19,13 +19,15 @@ package org.apache.spark.sql.execution.streaming.sources

import java.io.ByteArrayOutputStream

import org.scalatest.time.SpanSugar._

import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.streaming.{StreamTest, Trigger}

class ConsoleWriterSuite extends StreamTest {
import testImplicits._

test("console") {
test("microbatch - default") {
val input = MemoryStream[Int]

val captured = new ByteArrayOutputStream()
@@ -77,7 +79,7 @@ class ConsoleWriterSuite extends StreamTest {
|""".stripMargin)
}

test("console with numRows") {
test("microbatch - with numRows") {
val input = MemoryStream[Int]

val captured = new ByteArrayOutputStream()
@@ -106,7 +108,7 @@ class ConsoleWriterSuite extends StreamTest {
|""".stripMargin)
}

test("console with truncation") {
test("microbatch - truncation") {
val input = MemoryStream[String]

val captured = new ByteArrayOutputStream()
@@ -132,4 +134,20 @@
|
|""".stripMargin)
}

test("continuous - default") {
val captured = new ByteArrayOutputStream()
Console.withOut(captured) {
val input = spark.readStream
.format("rate")
.option("numPartitions", "1")
.option("rowsPerSecond", "5")
.load()
.select('value)

val query = input.writeStream.format("console").trigger(Trigger.Continuous(200)).start()
assert(query.isActive)
query.stop()
}
}
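// Illustration only (not part of this commit): the same continuous smoke test
// could be run with explicit sink options. It still asserts only that the query
// starts, since the timing of continuous epochs makes output assertions flaky.
test("continuous - with numRows and truncate (sketch)") {
  val input = spark.readStream
    .format("rate")
    .option("numPartitions", "1")
    .option("rowsPerSecond", "5")
    .load()
    .select('value)

  val query = input.writeStream
    .format("console")
    .option("numRows", "5")
    .option("truncate", "false")
    .trigger(Trigger.Continuous(200))
    .start()
  try {
    assert(query.isActive)
  } finally {
    query.stop()
  }
}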
}
