diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 94820376ff7e7..f2aa3259731d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql.execution.streaming import java.util.Optional -import scala.collection.JavaConverters._ - import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter +import org.apache.spark.sql.execution.streaming.sources.{ConsoleContinuousWriter, ConsoleMicroBatchWriter, ConsoleWriter} import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -37,16 +36,25 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) class ConsoleSinkProvider extends DataSourceV2 with MicroBatchWriteSupport + with ContinuousWriteSupport with DataSourceRegister with CreatableRelationProvider { override def createMicroBatchWriter( queryId: String, - epochId: Long, + batchId: Long, schema: StructType, mode: OutputMode, options: DataSourceV2Options): Optional[DataSourceV2Writer] = { - Optional.of(new ConsoleWriter(epochId, schema, options)) + Optional.of(new ConsoleMicroBatchWriter(batchId, schema, options)) + } + + override def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + Optional.of(new ConsoleContinuousWriter(schema, options)) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index 361979984bbec..6fb61dff60045 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -20,45 +20,85 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.types.StructType -/** - * A [[DataSourceV2Writer]] that collects results to the driver and prints them in the console. - * Generated by [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. - * - * This sink should not be used for production, as it requires sending all rows to the driver - * and does not support recovery. - */ -class ConsoleWriter(batchId: Long, schema: StructType, options: DataSourceV2Options) - extends DataSourceV2Writer with Logging { +/** Common methods used to create writes for the the console sink */ +trait ConsoleWriter extends Logging { + + def options: DataSourceV2Options + // Number of rows to display, by default 20 rows - private val numRowsToShow = options.getInt("numRows", 20) + protected val numRowsToShow = options.getInt("numRows", 20) // Truncate the displayed data if it is too long, by default it is true - private val isTruncated = options.getBoolean("truncate", true) + protected val isTruncated = options.getBoolean("truncate", true) assert(SparkSession.getActiveSession.isDefined) - private val spark = SparkSession.getActiveSession.get + protected val spark = SparkSession.getActiveSession.get + + def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory - override def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory + def abort(messages: Array[WriterCommitMessage]): Unit = {} - override def commit(messages: Array[WriterCommitMessage]): Unit = synchronized { - val batch = messages.collect { + protected def printRows( + commitMessages: Array[WriterCommitMessage], + schema: StructType, + printMessage: String): Unit = { + val rows = commitMessages.collect { case PackedRowCommitMessage(rows) => rows }.flatten // scalastyle:off println println("-------------------------------------------") - println(s"Batch: $batchId") + println(printMessage) println("-------------------------------------------") // scalastyle:off println - spark.createDataFrame( - spark.sparkContext.parallelize(batch), schema) + spark + .createDataFrame(spark.sparkContext.parallelize(rows), schema) .show(numRowsToShow, isTruncated) } +} + + +/** + * A [[DataSourceV2Writer]] that collects results from a micro-batch query to the driver and + * prints them in the console. Created by + * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. + * + * This sink should not be used for production, as it requires sending all rows to the driver + * and does not support recovery. + */ +class ConsoleMicroBatchWriter(batchId: Long, schema: StructType, val options: DataSourceV2Options) + extends DataSourceV2Writer with ConsoleWriter { + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + printRows(messages, schema, s"Batch: $batchId") + } + + override def toString(): String = { + s"ConsoleMicroBatchWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + } +} - override def abort(messages: Array[WriterCommitMessage]): Unit = {} - override def toString(): String = s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" +/** + * A [[DataSourceV2Writer]] that collects results from a continuous query to the driver and + * prints them in the console. Created by + * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. + * + * This sink should not be used for production, as it requires sending all rows to the driver + * and does not support recovery. + */ +class ConsoleContinuousWriter(schema: StructType, val options: DataSourceV2Options) + extends ContinuousWriter with ConsoleWriter { + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + printRows(messages, schema, s"Continuous processing epoch $epochId") + } + + override def toString(): String = { + s"ConsoleContinuousWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala index 60ffee9b9b42c..55acf2ba28d2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.ByteArrayOutputStream +import org.scalatest.time.SpanSugar._ + import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.streaming.{StreamTest, Trigger} class ConsoleWriterSuite extends StreamTest { import testImplicits._ - test("console") { + test("microbatch - default") { val input = MemoryStream[Int] val captured = new ByteArrayOutputStream() @@ -77,7 +79,7 @@ class ConsoleWriterSuite extends StreamTest { |""".stripMargin) } - test("console with numRows") { + test("microbatch - with numRows") { val input = MemoryStream[Int] val captured = new ByteArrayOutputStream() @@ -106,7 +108,7 @@ class ConsoleWriterSuite extends StreamTest { |""".stripMargin) } - test("console with truncation") { + test("microbatch - truncation") { val input = MemoryStream[String] val captured = new ByteArrayOutputStream() @@ -132,4 +134,20 @@ class ConsoleWriterSuite extends StreamTest { | |""".stripMargin) } + + test("continuous - default") { + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val input = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "5") + .load() + .select('value) + + val query = input.writeStream.format("console").trigger(Trigger.Continuous(200)).start() + assert(query.isActive) + query.stop() + } + } }