Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-23052][SS] Migrate ConsoleSink to data source V2 api. #20243

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,58 +17,36 @@

package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you move this file into the sources subdirectory to make it consistent with other v2 sources?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in fact this file can be merged into the ConsoleWriter.scala. The combined file will be named console.scala

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can do this in a followup PR. It's not as simple as just moving it; we have to add an alias so that .format("org.apache.spark.sql.execution.streaming.ConsoleSinkProvider") continues to work.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

argh. okay. later then.

import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
// Number of rows to display, by default 20 rows
private val numRowsToShow = options.get("numRows").map(_.toInt).getOrElse(20)

// Truncate the displayed data if it is too long, by default it is true
private val isTruncated = options.get("truncate").map(_.toBoolean).getOrElse(true)
import java.util.Optional

// Track the batch id
private var lastBatchId = -1L

override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
val batchIdStr = if (batchId <= lastBatchId) {
s"Rerun batch: $batchId"
} else {
lastBatchId = batchId
s"Batch: $batchId"
}

// scalastyle:off println
println("-------------------------------------------")
println(batchIdStr)
println("-------------------------------------------")
// scalastyle:off println
data.sparkSession.createDataFrame(
data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
.show(numRowsToShow, isTruncated)
}
import scala.collection.JavaConverters._

override def toString(): String = s"ConsoleSink[numRows=$numRowsToShow, truncate=$isTruncated]"
}
import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister}
import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options}
import org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport
import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame)
extends BaseRelation {
override def schema: StructType = data.schema
}

class ConsoleSinkProvider extends StreamSinkProvider
class ConsoleSinkProvider extends DataSourceV2
with MicroBatchWriteSupport
with DataSourceRegister
with CreatableRelationProvider {
def createSink(
sqlContext: SQLContext,
parameters: Map[String, String],
partitionColumns: Seq[String],
outputMode: OutputMode): Sink = {
new ConsoleSink(parameters)

override def createMicroBatchWriter(
queryId: String,
epochId: Long,
schema: StructType,
mode: OutputMode,
options: DataSourceV2Options): Optional[DataSourceV2Writer] = {
Optional.of(new ConsoleWriter(epochId, schema, options))
}

def createRelation(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is createRelation used for? For batch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume so. I'm not familiar with it, but it's not on the streaming source codepath.

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.sources

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.sources.v2.DataSourceV2Options
import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.types.StructType

class ConsoleWriter(batchId: Long, schema: StructType, options: DataSourceV2Options)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docs and link it to the ConsoleSinkProvider since it's in a different file.

extends DataSourceV2Writer with Logging {
// Number of rows to display, by default 20 rows
private val numRowsToShow = options.getInt("numRows", 20)

// Truncate the displayed data if it is too long, by default it is true
private val isTruncated = options.getBoolean("truncate", true)

assert(SparkSession.getActiveSession.isDefined)
private val spark = SparkSession.getActiveSession.get

override def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory

override def commit(messages: Array[WriterCommitMessage]): Unit = synchronized {
val batch = messages.collect {
case PackedRowCommitMessage(rows) => rows
}.fold(Array())(_ ++ _)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this complicated fold? Just array.collect { ... } returns an Array .. isnt it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It returns an array of arrays of rows, which isn't what we need.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can use flatten instead of fold. Much cleaner.


// scalastyle:off println
println("-------------------------------------------")
println(s"Batch: $batchId")
println("-------------------------------------------")
// scalastyle:off println
spark.createDataFrame(
spark.sparkContext.parallelize(batch), schema)
.show(numRowsToShow, isTruncated)
}

override def abort(messages: Array[WriterCommitMessage]): Unit = {}

override def toString(): String = s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.sources

import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}

/**
* A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery
* to a [[org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer]] on the driver.
*/
case object PackedRowWriterFactory extends DataWriterFactory[Row] {
def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
new PackedRowDataWriter()
}
}

case class PackedRowCommitMessage(rows: Array[Row]) extends WriterCommitMessage
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docs.


class PackedRowDataWriter() extends DataWriter[Row] with Logging {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docs.

private val data = mutable.Buffer[Row]()

override def write(row: Row): Unit = data.append(row)

override def commit(): PackedRowCommitMessage = {
val msg = PackedRowCommitMessage(data.clone().toArray)
Copy link
Contributor

@tdas tdas Jan 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are you cloning and then calling toArray? Just data.toArray will create an immutable copy.

data.clear()
msg
}

override def abort(): Unit = data.clear()
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2}
import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport}

/**
* Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
Expand Down Expand Up @@ -279,18 +280,26 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
useTempCheckpointLocation = true,
trigger = trigger)
} else {
val dataSource =
DataSource(
df.sparkSession,
className = source,
options = extraOptions.toMap,
partitionColumns = normalizedParCols.getOrElse(Nil))
val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are checking for the same conditions here as well as in the StreamingQueryManager.createQuery. I think we need to refactor this, probably sometime in the future once we get rid of v1 completely.

Either way, we should immediately add a general test suite (say StreamingDataSourceV2Suite) that tests these cases with various fake data sources.

val sink = (ds.newInstance(), trigger) match {
case (w: ContinuousWriteSupport, _: ContinuousTrigger) => w
case (_, _: ContinuousTrigger) => throw new AnalysisException(
s"Data source $source does not support continuous writing")
case (w: MicroBatchWriteSupport, _) => w
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isnt there a case where it does not have MicroBatchWriteSupport, but the trigger is ProcessingTime/OneTime? That should have a different error message.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, we have to just fall back to the V1 path, because V1 sinks don't have MicroBatchWriteSupport.

case _ =>
val ds = DataSource(
df.sparkSession,
className = source,
options = extraOptions.toMap,
partitionColumns = normalizedParCols.getOrElse(Nil))
ds.createSink(outputMode)
}
df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
df,
extraOptions.toMap,
dataSource.createSink(outputMode),
sink,
outputMode,
useTempCheckpointLocation = source == "console",
recoverFromCheckpointLocation = true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ org.apache.spark.sql.sources.FakeSourceFour
org.apache.fakesource.FakeExternalSourceOne
org.apache.fakesource.FakeExternalSourceTwo
org.apache.fakesource.FakeExternalSourceThree
org.apache.spark.sql.streaming.sources.FakeStreamingMicroBatchOnly
org.apache.spark.sql.streaming.sources.FakeStreamingContinuousOnly
org.apache.spark.sql.streaming.sources.FakeStreamingBothModes
org.apache.spark.sql.streaming.sources.FakeStreamingNeitherMode
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.sources
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


import java.io.ByteArrayOutputStream

import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.StreamTest

class ConsoleWriterSuite extends StreamTest {
import testImplicits._

test("console") {
val input = MemoryStream[Int]

val captured = new ByteArrayOutputStream()
Console.withOut(captured) {
val query = input.toDF().writeStream.format("console").start()
try {
input.addData(1, 2, 3)
query.processAllAvailable()
input.addData(4, 5, 6)
query.processAllAvailable()
input.addData()
query.processAllAvailable()
} finally {
query.stop()
}
}

assert(captured.toString() ==
"""-------------------------------------------
|Batch: 0
|-------------------------------------------
|+-----+
||value|
|+-----+
|| 1|
|| 2|
|| 3|
|+-----+
|
|-------------------------------------------
|Batch: 1
|-------------------------------------------
|+-----+
||value|
|+-----+
|| 4|
|| 5|
|| 6|
|+-----+
|
|-------------------------------------------
|Batch: 2
|-------------------------------------------
|+-----+
||value|
|+-----+
|+-----+
|
|""".stripMargin)
}

test("console with numRows") {
val input = MemoryStream[Int]

val captured = new ByteArrayOutputStream()
Console.withOut(captured) {
val query = input.toDF().writeStream.format("console").option("NUMROWS", 2).start()
try {
input.addData(1, 2, 3)
query.processAllAvailable()
} finally {
query.stop()
}
}

assert(captured.toString() ==
"""-------------------------------------------
|Batch: 0
|-------------------------------------------
|+-----+
||value|
|+-----+
|| 1|
|| 2|
|+-----+
|only showing top 2 rows
|
|""".stripMargin)
}

test("console with truncation") {
val input = MemoryStream[String]

val captured = new ByteArrayOutputStream()
Console.withOut(captured) {
val query = input.toDF().writeStream.format("console").option("TRUNCATE", true).start()
try {
input.addData("123456789012345678901234567890")
query.processAllAvailable()
} finally {
query.stop()
}
}

assert(captured.toString() ==
"""-------------------------------------------
|Batch: 0
|-------------------------------------------
|+--------------------+
|| value|
|+--------------------+
||12345678901234567...|
|+--------------------+
|
|""".stripMargin)
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could have a test to check numrows, something like this:

  test("console with numRows") {
    val input = MemoryStream[Int]

    val captured = new ByteArrayOutputStream()
    Console.withOut(captured) {
      val query = input.toDF().writeStream.format("console").option("NUMROWS", 2).start()
      try {
        input.addData(1, 2, 3)
        query.processAllAvailable()
      } finally {
        query.stop()
      }
    }

    assert(captured.toString() ==
      """-------------------------------------------
        |Batch: 0
        |-------------------------------------------
        |+-----+
        ||value|
        |+-----+
        ||    1|
        ||    2|
        |+-----+
        |only showing top 2 rows
        |
        |""".stripMargin)
  }

  test("console with truncation") {
    val input = MemoryStream[String]

    val captured = new ByteArrayOutputStream()
    Console.withOut(captured) {
      val query = input.toDF().writeStream.format("console").option("TRUNCATE", true).start()
      try {
        input.addData("123456789012345678901234567890")
        query.processAllAvailable()
      } finally {
        query.stop()
      }
    }

    assert(captured.toString() ==
      """-------------------------------------------
        |Batch: 0
        |-------------------------------------------
        |+--------------------+
        ||               value|
        |+--------------------+
        ||12345678901234567...|
        |+--------------------+
        |
        |""".stripMargin)
  }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed we could. Thanks for writing out the tests!

Loading