Commit

support accessing SQLConf inside tasks

cloud-fan committed May 11, 2018
1 parent e39b7d0 commit 7c7caf8
Showing 12 changed files with 184 additions and 45 deletions.
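What the commit enables, shown as a hedged sketch rather than as part of the diff: task code can now call SQLConf.get and see SQL configs set on the driver, because SQLExecution copies the session's configs into the job's local properties and SQLConf.get returns a ReadOnlySQLConf backed by the TaskContext on executors. The object name and the key "spark.sql.mySetting" below are illustrative; the local-cluster master mirrors the new ExecutorSideSQLConfSuite at the end of this diff.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

object SQLConfInTaskSketch {
  def main(args: Array[String]): Unit = {
    // local-cluster so that tasks run in executor JVMs separate from the driver,
    // mirroring ExecutorSideSQLConfSuite below.
    val spark = SparkSession.builder()
      .master("local-cluster[2,1,1024]")
      .appName("sqlconf-in-task")
      .getOrCreate()
    import spark.implicits._

    // Illustrative key; any "spark.*" conf set on the session is propagated.
    spark.conf.set("spark.sql.mySetting", "abc")

    val seen = spark.range(10).mapPartitions { iter =>
      // Runs on executors: SQLConf.get now resolves against the task's local properties.
      val conf = SQLConf.get
      iter.map(_ => conf.getConfString("spark.sql.mySetting", "<unset>"))
    }.collect()

    assert(seen.forall(_ == "abc"))
    spark.stop()
  }
}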
2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -178,4 +178,6 @@ private[spark] class TaskContextImpl(

private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException

// TODO: shall we publish it and define it in `TaskContext`?
private[spark] def getLocalProperties(): Properties = localProperties
}
@@ -78,17 +78,4 @@ private[sql] object CreateJacksonParser extends Serializable {
def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = {
jsonFactory.createParser(new InputStreamReader(is, enc))
}

def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
val ba = row.getBinary(0)

jsonFactory.createParser(ba, 0, ba.length)
}

def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
val binary = row.getBinary(0)
val sd = getStreamDecoder(enc, binary, binary.length)

jsonFactory.createParser(sd)
}
}
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.internal

import java.util.{Map => JMap}

import org.apache.spark.{TaskContext, TaskContextImpl}
import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader}

/**
 * A read-only SQLConf that is created by tasks running on the executor side. It reads the
 * configs from the local properties that are propagated from the driver to the executors.
*/
class ReadOnlySQLConf(context: TaskContext) extends SQLConf {

@transient override val settings: JMap[String, String] = {
context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]]
}

@transient override protected val reader: ConfigReader = {
new ConfigReader(new TaskContextConfigProvider(context))
}

override protected def setConfWithCheck(key: String, value: String): Unit = {
throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
}

override def unsetConf(key: String): Unit = {
throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
}

override def unsetConf(entry: ConfigEntry[_]): Unit = {
throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
}

override def clear(): Unit = {
throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
}
}

class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider {
override def get(key: String): Option[String] = Option(context.getLocalProperty(key))
}
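For context, a minimal sketch (not part of the commit) of the local-property plumbing that ReadOnlySQLConf and TaskContextConfigProvider sit on top of: a property set on the driver through SparkContext.setLocalProperty travels with each task submitted from that thread and is readable on the executor through TaskContext.getLocalProperty. The key name below is illustrative.

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object LocalPropertySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("local-props"))

    // Driver side: attach a property to jobs submitted from this thread.
    sc.setLocalProperty("spark.sql.example.key", "42")

    val values = sc.parallelize(1 to 4, 2).map { _ =>
      // Task side: the property is visible through the TaskContext.
      Option(TaskContext.get().getLocalProperty("spark.sql.example.key")).getOrElse("<unset>")
    }.collect()

    assert(values.forall(_ == "42"))
    sc.stop()
  }
}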
@@ -27,13 +27,12 @@ import scala.util.matching.Regex

import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.util.Utils

////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines the configuration options for Spark SQL.
@@ -107,7 +106,13 @@ object SQLConf {
* run tests in parallel. At the time this feature was implemented, this was a no-op since we
* run unit tests (that does not involve SparkSession) in serial order.
*/
def get: SQLConf = confGetter.get()()
def get: SQLConf = {
if (TaskContext.get != null) {
new ReadOnlySQLConf(TaskContext.get())
} else {
confGetter.get()()
}
}

val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
.internal()
@@ -1274,17 +1279,11 @@
class SQLConf extends Serializable with Logging {
import SQLConf._

if (Utils.isTesting && SparkEnv.get != null) {
// assert that we're only accessing it on the driver.
assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER,
"SQLConf should only be created and accessed on the driver.")
}

/** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
@transient protected[spark] val settings = java.util.Collections.synchronizedMap(
new java.util.HashMap[String, String]())

@transient private val reader = new ConfigReader(settings)
@transient protected val reader = new ConfigReader(settings)

/** ************************ Spark SQL Params/Hints ******************* */

@@ -1734,7 +1733,7 @@ class SQLConf extends Serializable with Logging {
settings.containsKey(key)
}

private def setConfWithCheck(key: String, value: String): Unit = {
protected def setConfWithCheck(key: String, value: String): Unit = {
settings.put(key, value)
}

4 changes: 3 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1607,7 +1607,9 @@ class Dataset[T] private[sql](
*/
@Experimental
@InterfaceStability.Evolving
def reduce(func: (T, T) => T): T = rdd.reduce(func)
def reduce(func: (T, T) => T): T = withNewRDDExecutionId {
rdd.reduce(func)
}

/**
* :: Experimental ::
@@ -68,16 +68,27 @@ object SQLExecution {
// sparkContext.getCallSite() would first try to pick up any call site that was previously
// set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
// streaming queries would give us call site like "run at <unknown>:0"
val callSite = sparkSession.sparkContext.getCallSite()
val callSite = sc.getCallSite()

sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
// Set all the specified SQL configs to local properties, so that they can be available at
// the executor side.
val allConfigs = sparkSession.sessionState.conf.getAllConfs
allConfigs.foreach {
// Excludes external configs defined by users.
case (key, value) if key.startsWith("spark") => sc.setLocalProperty(key, value)
}

sc.listenerBus.post(SparkListenerSQLExecutionStart(
executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
try {
body
} finally {
sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
sc.listenerBus.post(SparkListenerSQLExecutionEnd(
executionId, System.currentTimeMillis()))
allConfigs.foreach {
case (key, _) => sc.setLocalProperty(key, null)
}
}
} finally {
executionIdToQueryExecution.remove(executionId)
@@ -90,12 +101,23 @@
* thread from the original one, this method can be used to connect the Spark jobs in this action
* with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`.
*/
def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = {
def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = {
val sc = sparkSession.sparkContext
val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
// Set all the specified SQL configs to local properties, so that they can be available at
// the executor side.
val allConfigs = sparkSession.sessionState.conf.getAllConfs
allConfigs.foreach {
// Excludes external configs defined by users.
case (key, value) if key.startsWith("spark") => sc.setLocalProperty(key, value)
}
try {
sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
body
} finally {
allConfigs.foreach {
case (key, _) => sc.setLocalProperty(key, null)
}
sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
}
}
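A hedged sketch of the calling pattern the new withExecutionId signature targets (the helper below is illustrative, not from the commit): work launched on a separate thread wraps its body in SQLExecution.withExecutionId, which now takes the SparkSession so it can copy the session's SQL configs into that thread's local properties and clear them afterwards, as BroadcastExchangeExec.relationFuture and SubqueryExec do further down.

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution

object SideThreadExecutionSketch {
  // Illustrative helper: run `body` on another thread while keeping it attached
  // to an existing SQL execution and to the session's SQL configs.
  def run[T](spark: SparkSession, executionId: String)(body: => T)
            (implicit ec: ExecutionContext): Future[T] = Future {
    SQLExecution.withExecutionId(spark, executionId) {
      body
    }
  }
}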
@@ -629,7 +629,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
Future {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
SQLExecution.withExecutionId(sparkContext, executionId) {
SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
val beforeCollect = System.nanoTime()
// Note that we use .executeCollect() because we don't want to convert data to Scala types
val rows: Array[InternalRow] = child.executeCollect()
@@ -99,12 +99,7 @@ object TextInputJsonDataSource extends JsonDataSource {

def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = {
val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions)
val rdd: RDD[InternalRow] = sampled.queryExecution.toRdd
val rowParser = parsedOptions.encoding.map { enc =>
CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow)
}.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow))

JsonInferSchema.infer(rdd, parsedOptions, rowParser)
JsonInferSchema.infer(sampled, parsedOptions, CreateJacksonParser.string)
}

private def createBaseDataset(
@@ -165,7 +160,8 @@ object MultiLineJsonDataSource extends JsonDataSource {
.map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream))
.getOrElse(createParser(_: JsonFactory, _: PortableDataStream))

JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
JsonInferSchema.infer[PortableDataStream](
sparkSession.createDataset(sampled)(Encoders.javaSerialization), parsedOptions, parser)
}

private def createBaseRdd(
@@ -22,7 +22,7 @@ import java.util.Comparator
import com.fasterxml.jackson.core._

import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Encoders}
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
import org.apache.spark.sql.catalyst.json.JSONOptions
@@ -39,14 +39,14 @@ private[sql] object JsonInferSchema {
* 3. Replace any remaining null fields with string, the top type
*/
def infer[T](
json: RDD[T],
json: Dataset[T],
configOptions: JSONOptions,
createParser: (JsonFactory, T) => JsonParser): StructType = {
val parseMode = configOptions.parseMode
val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord

// perform schema inference on each row and merge afterwards
val rootType = json.mapPartitions { iter =>
val inferredTypes = json.mapPartitions { iter =>
val factory = new JsonFactory()
configOptions.setJacksonOptions(factory)
iter.flatMap { row =>
@@ -67,8 +67,15 @@
}
}
}
}.fold(StructType(Nil))(
compatibleRootType(columnNameOfCorruptRecord, parseMode))
}(Encoders.javaSerialization)

// TODO: use `Dataset.fold` once we have it.
val rootType = try {
inferredTypes.reduce(compatibleRootType(columnNameOfCorruptRecord, parseMode))
} catch {
case e: UnsupportedOperationException if e.getMessage == "empty collection" =>
StructType(Nil)
}

canonicalizeType(rootType) match {
case Some(st: StructType) => st
@@ -69,7 +69,7 @@ case class BroadcastExchangeExec(
Future {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
SQLExecution.withExecutionId(sparkContext, executionId) {
SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
try {
val beforeCollect = System.nanoTime()
// Use executeCollect/executeCollectIterator to avoid conversion to Scala types
@@ -1374,7 +1374,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
test("SPARK-6245 JsonInferSchema.infer on empty RDD") {
// This is really a test that it doesn't throw an exception
val emptySchema = JsonInferSchema.infer(
empty.rdd,
empty,
new JSONOptions(Map.empty[String, String], "GMT"),
CreateJacksonParser.string)
assert(StructType(Seq()) === emptySchema)
@@ -1401,7 +1401,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {

test("SPARK-8093 Erase empty structs") {
val emptySchema = JsonInferSchema.infer(
emptyRecords.rdd,
emptyRecords,
new JSONOptions(Map.empty[String, String], "GMT"),
CreateJacksonParser.string)
assert(StructType(Seq()) === emptySchema)
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.internal

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.test.SQLTestUtils

class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
import testImplicits._

protected var spark: SparkSession = null

// Create a new [[SparkSession]] running in local-cluster mode.
override def beforeAll(): Unit = {
super.beforeAll()
spark = SparkSession.builder()
.master("local-cluster[2,1,1024]")
.appName("testing")
.getOrCreate()
}

override def afterAll(): Unit = {
spark.stop()
spark = null
}

test("ReadonlySQLConf is correctly created at the executor side") {
SQLConf.get.setConfString("spark.sql.x", "a")
try {
val checks = spark.range(10).mapPartitions { it =>
val conf = SQLConf.get
Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a")
}.collect()
assert(checks.forall(_ == true))
} finally {
SQLConf.get.unsetConf("spark.sql.x")
}
}

test("case-sensitive config should work for json schema inference") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
withTempPath { path =>
val pathString = path.getCanonicalPath
spark.range(10).select('id.as("ID")).write.json(pathString)
spark.range(10).write.mode("append").json(pathString)
assert(spark.read.json(pathString).columns.toSet == Set("id", "ID"))
}
}
}
}
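A hedged companion check, written as if it sat inside the first test above (it is not part of the commit): because executors see a ReadOnlySQLConf whose setConfWithCheck override throws, any attempt to mutate the conf inside a task is expected to fail.

// Hypothetical extra assertion in the style of the first test above.
val mutationChecks = spark.range(10).mapPartitions { _ =>
  val threw = try {
    SQLConf.get.setConfString("spark.sql.x", "b") // should throw on executors
    false
  } catch {
    case _: UnsupportedOperationException => true
  }
  Iterator(threw)
}.collect()
assert(mutationChecks.forall(_ == true))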
