[SPARK-15593][SQL] Add DataFrameWriter.foreach to allow the user to consume data in ContinuousQuery

## What changes were proposed in this pull request?

* Add DataFrameWriter.foreach to allow the user to consume data in a ContinuousQuery
  * ForeachWriter is the interface the user implements to consume partitions of data
* Add a type parameter T to DataFrameWriter

Usage
```Scala
val ds = spark.read....stream().as[String]
ds.....write
  .queryName(...)
  .option("checkpointLocation", ...)
  .foreach(new ForeachWriter[String] {
    def open(partitionId: Long, version: Long): Boolean = {
      // Prepare any resources needed for this partition.
      // Check `version` if possible and return `false` to skip processing
      // when this partition/version has already been handled.
      true
    }

    def process(value: String): Unit = {
      // Process a single record.
    }

    def close(errorOrNull: Throwable): Unit = {
      // Release the resources for this partition.
      // Check `errorOrNull` and handle the error if necessary.
    }
  })
```
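
For illustration, here is one way the skeleton above could be filled in: a writer that writes each partition/version of the stream to its own local file and uses `open`'s return value to skip data that was already written by an earlier attempt. This is only a sketch and not part of the patch — the class name, output directory, and file-naming scheme are assumptions, and the files land on whichever executor happens to run each partition.

```scala
import java.io.{File, FileWriter}
import org.apache.spark.sql.ForeachWriter

// Illustrative ForeachWriter: one output file per (partition, version).
class FilePerPartitionWriter(outputDir: String) extends ForeachWriter[String] {
  private var target: File = _
  private var out: FileWriter = _

  def open(partitionId: Long, version: Long): Boolean = {
    target = new File(outputDir, s"part-$partitionId-$version.txt")
    if (target.exists()) {
      // This partition/version was already written by a previous attempt: skip it.
      false
    } else {
      out = new FileWriter(target)
      true
    }
  }

  def process(value: String): Unit = {
    out.write(value)
    out.write("\n")
  }

  def close(errorOrNull: Throwable): Unit = {
    if (out != null) {
      out.close()
      if (errorOrNull != null) {
        // Best-effort cleanup so a retry of this partition/version starts from scratch.
        target.delete()
      }
    }
  }
}
```

Such a writer would be passed to `.foreach(new FilePerPartitionWriter("/some/dir"))` exactly as in the snippet above.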

## How was this patch tested?

New unit tests.
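
For context, a minimal sketch of the kind of test this enables, assuming the in-memory test source `MemoryStream`, a local SparkSession bound to `spark`, and `ContinuousQuery.processAllAvailable()`; the `Collected` and `CollectingWriter` names are illustrative, not part of the patch:

```scala
import java.util.concurrent.ConcurrentLinkedQueue
import scala.collection.JavaConverters._

import org.apache.spark.sql.ForeachWriter
import org.apache.spark.sql.execution.streaming.MemoryStream

// JVM-global buffer: the writer is serialized into tasks, so a captured local
// collection would not be visible back on the driver. A singleton object is,
// as long as the test runs in local mode (a single JVM).
object Collected {
  val values = new ConcurrentLinkedQueue[Int]()
}

class CollectingWriter extends ForeachWriter[Int] {
  def open(partitionId: Long, version: Long): Boolean = true
  def process(value: Int): Unit = { Collected.values.add(value) }
  def close(errorOrNull: Throwable): Unit = {}
}

import spark.implicits._                    // assumed local SparkSession `spark`
implicit val sqlContext = spark.sqlContext  // MemoryStream resolves an implicit SQLContext

val input = MemoryStream[Int]
val query = input.toDS().write.foreach(new CollectingWriter)

input.addData(1, 2, 3)
query.processAllAvailable()                 // process everything the source has emitted so far
assert(Collected.values.asScala.toSeq.sorted == Seq(1, 2, 3))
query.stop()
```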

Author: Shixiong Zhu <[email protected]>

Closes apache#13342 from zsxwing/foreach.
zsxwing authored and tdas committed Jun 10, 2016
1 parent 5a3533e commit 00c3101
Showing 6 changed files with 413 additions and 42 deletions.
150 changes: 110 additions & 40 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource, HadoopFsRelation}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, StreamExecution}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{ContinuousQuery, OutputMode, ProcessingTime, Trigger}
import org.apache.spark.util.Utils
@@ -40,7 +40,9 @@ import org.apache.spark.util.Utils
*
* @since 1.4.0
*/
final class DataFrameWriter private[sql](df: DataFrame) {
final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {

private val df = ds.toDF()

/**
* Specifies the behavior when data or table already exists. Options include:
@@ -51,7 +53,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
def mode(saveMode: SaveMode): DataFrameWriter = {
def mode(saveMode: SaveMode): DataFrameWriter[T] = {
// mode() is used for non-continuous queries
// outputMode() is used for continuous queries
assertNotStreaming("mode() can only be called on non-continuous queries")
@@ -68,7 +70,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
def mode(saveMode: String): DataFrameWriter = {
def mode(saveMode: String): DataFrameWriter[T] = {
// mode() is used for non-continuous queries
// outputMode() is used for continuous queries
assertNotStreaming("mode() can only be called on non-continuous queries")
@@ -93,7 +95,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
@Experimental
def outputMode(outputMode: OutputMode): DataFrameWriter = {
def outputMode(outputMode: OutputMode): DataFrameWriter[T] = {
assertStreaming("outputMode() can only be called on continuous queries")
this.outputMode = outputMode
this
@@ -109,7 +111,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
@Experimental
def outputMode(outputMode: String): DataFrameWriter = {
def outputMode(outputMode: String): DataFrameWriter[T] = {
assertStreaming("outputMode() can only be called on continuous queries")
this.outputMode = outputMode.toLowerCase match {
case "append" =>
@@ -147,7 +149,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
@Experimental
def trigger(trigger: Trigger): DataFrameWriter = {
def trigger(trigger: Trigger): DataFrameWriter[T] = {
assertStreaming("trigger() can only be called on continuous queries")
this.trigger = trigger
this
@@ -158,7 +160,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
def format(source: String): DataFrameWriter = {
def format(source: String): DataFrameWriter[T] = {
this.source = source
this
}
@@ -168,7 +170,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
def option(key: String, value: String): DataFrameWriter = {
def option(key: String, value: String): DataFrameWriter[T] = {
this.extraOptions += (key -> value)
this
}
@@ -178,28 +180,28 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 2.0.0
*/
def option(key: String, value: Boolean): DataFrameWriter = option(key, value.toString)
def option(key: String, value: Boolean): DataFrameWriter[T] = option(key, value.toString)

/**
* Adds an output option for the underlying data source.
*
* @since 2.0.0
*/
def option(key: String, value: Long): DataFrameWriter = option(key, value.toString)
def option(key: String, value: Long): DataFrameWriter[T] = option(key, value.toString)

/**
* Adds an output option for the underlying data source.
*
* @since 2.0.0
*/
def option(key: String, value: Double): DataFrameWriter = option(key, value.toString)
def option(key: String, value: Double): DataFrameWriter[T] = option(key, value.toString)

/**
* (Scala-specific) Adds output options for the underlying data source.
*
* @since 1.4.0
*/
def options(options: scala.collection.Map[String, String]): DataFrameWriter = {
def options(options: scala.collection.Map[String, String]): DataFrameWriter[T] = {
this.extraOptions ++= options
this
}
@@ -209,7 +211,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
def options(options: java.util.Map[String, String]): DataFrameWriter = {
def options(options: java.util.Map[String, String]): DataFrameWriter[T] = {
this.options(options.asScala)
this
}
@@ -232,7 +234,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
@scala.annotation.varargs
def partitionBy(colNames: String*): DataFrameWriter = {
def partitionBy(colNames: String*): DataFrameWriter[T] = {
this.partitioningColumns = Option(colNames)
this
}
@@ -246,7 +248,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0
*/
@scala.annotation.varargs
def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = {
def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter[T] = {
this.numBuckets = Option(numBuckets)
this.bucketColumnNames = Option(colName +: colNames)
this
@@ -260,7 +262,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0
*/
@scala.annotation.varargs
def sortBy(colName: String, colNames: String*): DataFrameWriter = {
def sortBy(colName: String, colNames: String*): DataFrameWriter[T] = {
this.sortColumnNames = Option(colName +: colNames)
this
}
@@ -301,7 +303,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
@Experimental
def queryName(queryName: String): DataFrameWriter = {
def queryName(queryName: String): DataFrameWriter[T] = {
assertStreaming("queryName() can only be called on continuous queries")
this.extraOptions += ("queryName" -> queryName)
this
@@ -337,16 +339,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
val queryName =
extraOptions.getOrElse(
"queryName", throw new AnalysisException("queryName must be specified for memory sink"))
val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified =>
new Path(userSpecified).toUri.toString
}.orElse {
val checkpointConfig: Option[String] =
df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION)

checkpointConfig.map { location =>
new Path(location, queryName).toUri.toString
}
}.getOrElse {
val checkpointLocation = getCheckpointLocation(queryName, failIfNotSet = false).getOrElse {
Utils.createTempDir(namePrefix = "memory.stream").getCanonicalPath
}

@@ -378,28 +371,105 @@ final class DataFrameWriter private[sql](df: DataFrame) {
className = source,
options = extraOptions.toMap,
partitionColumns = normalizedParCols.getOrElse(Nil))

val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
val checkpointLocation = extraOptions.get("checkpointLocation")
.orElse {
df.sparkSession.sessionState.conf.checkpointLocation.map { l =>
new Path(l, queryName).toUri.toString
}
}.getOrElse {
throw new AnalysisException("checkpointLocation must be specified either " +
"through option() or SQLConf")
}

df.sparkSession.sessionState.continuousQueryManager.startQuery(
queryName,
checkpointLocation,
getCheckpointLocation(queryName, failIfNotSet = true).get,
df,
dataSource.createSink(outputMode),
outputMode,
trigger)
}
}

/**
* :: Experimental ::
* Starts the execution of the streaming query, which will continually send results to the given
* [[ForeachWriter]] as new data arrives. The [[ForeachWriter]] can be used to send the data
* generated by the [[DataFrame]]/[[Dataset]] to an external system. The returned
* [[ContinuousQuery]] object can be used to interact with the stream.
*
* Scala example:
* {{{
* datasetOfString.write.foreach(new ForeachWriter[String] {
*
* def open(partitionId: Long, version: Long): Boolean = {
* // open connection
* }
*
* def process(record: String) = {
* // write string to connection
* }
*
* def close(errorOrNull: Throwable): Unit = {
* // close the connection
* }
* })
* }}}
*
* Java example:
* {{{
* datasetOfString.write().foreach(new ForeachWriter<String>() {
*
* @Override
* public boolean open(long partitionId, long version) {
* // open connection
* }
*
* @Override
* public void process(String value) {
* // write string to connection
* }
*
* @Override
* public void close(Throwable errorOrNull) {
* // close the connection
* }
* });
* }}}
*
* @since 2.0.0
*/
@Experimental
def foreach(writer: ForeachWriter[T]): ContinuousQuery = {
assertNotBucketed("foreach")
assertStreaming(
"foreach() can only be called on streaming Datasets/DataFrames.")

val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
val sink = new ForeachSink[T](ds.sparkSession.sparkContext.clean(writer))(ds.exprEnc)
df.sparkSession.sessionState.continuousQueryManager.startQuery(
queryName,
getCheckpointLocation(queryName, failIfNotSet = false).getOrElse {
Utils.createTempDir(namePrefix = "foreach.stream").getCanonicalPath
},
df,
sink,
outputMode,
trigger)
}

/**
* Returns the checkpointLocation for a query. If `failIfNotSet` is `true` but the checkpoint
* location is not set, [[AnalysisException]] will be thrown. If `failIfNotSet` is `false`, `None`
* will be returned if the checkpoint location is not set.
*/
private def getCheckpointLocation(queryName: String, failIfNotSet: Boolean): Option[String] = {
val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified =>
new Path(userSpecified).toUri.toString
}.orElse {
df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION).map { location =>
new Path(location, queryName).toUri.toString
}
}
if (failIfNotSet && checkpointLocation.isEmpty) {
throw new AnalysisException("checkpointLocation must be specified either " +
"""through option("checkpointLocation", ...) or """ +
s"""SparkSession.conf.set("${SQLConf.CHECKPOINT_LOCATION.key}", ...)""")
}
checkpointLocation
}

/**
* Inserts the content of the [[DataFrame]] to the specified table. It requires that
* the schema of the [[DataFrame]] is the same as the schema of the table.
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2400,7 +2400,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
def write: DataFrameWriter = new DataFrameWriter(toDF())
def write: DataFrameWriter[T] = new DataFrameWriter[T](this)

/**
* Returns the content of the Dataset as a Dataset of JSON strings.