diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 4a66d6508ae0a..cf05c6c989655 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -158,7 +158,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { logInfo(s"Initializing execution hive, version $hiveExecutionVersion") new ClientWrapper( version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), - config = newTemporaryConfiguration()) + config = newTemporaryConfiguration(), + initClassLoader = Utils.getContextOrSparkClassLoader) } SessionState.setCurrentSessionState(executionHive.state) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 982ed63874a5f..42c2d4c98ffb2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -54,10 +54,13 @@ import org.apache.spark.sql.execution.QueryExecutionException * @param version the version of hive used when pick function calls that are not compatible. * @param config a collection of configuration options that will be added to the hive conf before * opening the hive client. + * @param initClassLoader the classloader used when creating the `state` field of + * this ClientWrapper. */ private[hive] class ClientWrapper( version: HiveVersion, - config: Map[String, String]) + config: Map[String, String], + initClassLoader: ClassLoader) extends ClientInterface with Logging { @@ -98,11 +101,18 @@ private[hive] class ClientWrapper( // Create an internal session state for this ClientWrapper. val state = { val original = Thread.currentThread().getContextClassLoader - Thread.currentThread().setContextClassLoader(getClass.getClassLoader) + // Switch to the initClassLoader. + Thread.currentThread().setContextClassLoader(initClassLoader) val ret = try { val oldState = SessionState.get() if (oldState == null) { val initialConf = new HiveConf(classOf[SessionState]) + // HiveConf is a Hadoop Configuration, which has a field of classLoader and + // the initial value will be the current thread's context class loader + // (i.e. initClassLoader at here). + // We call initialConf.setClassLoader(initClassLoader) at here to make + // this action explicit. + initialConf.setClassLoader(initClassLoader) config.foreach { case (k, v) => logDebug(s"Hive Config: $k=$v") initialConf.set(k, v) @@ -125,6 +135,7 @@ private[hive] class ClientWrapper( def conf: HiveConf = SessionState.get().getConf // TODO: should be a def?s + // When we create this val client, the HiveConf of it (conf) is the one associated with state. private val client = Hive.get(conf) /** @@ -132,13 +143,9 @@ private[hive] class ClientWrapper( */ private def withHiveState[A](f: => A): A = synchronized { val original = Thread.currentThread().getContextClassLoader - // This setContextClassLoader is used for Hive 0.12's metastore since Hive 0.12 will not - // internally override the context class loader of the current thread with the class loader - // associated with the HiveConf in `state`. - Thread.currentThread().setContextClassLoader(getClass.getClassLoader) // Set the thread local metastore client to the client associated with this ClientWrapper. Hive.set(client) - // Starting from Hive 0.13.0, setCurrentSessionState will use the classLoader associated + // setCurrentSessionState will use the classLoader associated // with the HiveConf in `state` to override the context class loader of the current // thread. shim.setCurrentSessionState(state) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 40c167926c8d6..5ae2dbb50d86b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -43,6 +43,11 @@ import org.apache.hadoop.hive.ql.session.SessionState */ private[client] sealed abstract class Shim { + /** + * Set the current SessionState to the given SessionState. Also, set the context classloader of + * the current thread to the one set in the HiveConf of this given `state`. + * @param state + */ def setCurrentSessionState(state: SessionState): Unit /** @@ -159,7 +164,15 @@ private[client] class Shim_v0_12 extends Shim { JBoolean.TYPE, JBoolean.TYPE) - override def setCurrentSessionState(state: SessionState): Unit = startMethod.invoke(null, state) + override def setCurrentSessionState(state: SessionState): Unit = { + // Starting from Hive 0.13, setCurrentSessionState will internally override + // the context class loader of the current thread by the class loader set in + // the conf of the SessionState. So, for this Hive 0.12 shim, we add the same + // behavior and make shim.setCurrentSessionState of all Hive versions have the + // consistent behavior. + Thread.currentThread().setContextClassLoader(state.getConf.getClassLoader) + startMethod.invoke(null, state) + } override def getDataLocation(table: Table): Option[String] = Option(getDataLocationMethod.invoke(table)).map(_.toString()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 69cfc5c3c3216..0934ad5034671 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -95,9 +95,8 @@ private[hive] object IsolatedClientLoader { * @param config A set of options that will be added to the HiveConf of the constructed client. * @param isolationOn When true, custom versions of barrier classes will be constructed. Must be * true unless loading the version of hive that is on Sparks classloader. - * @param rootClassLoader The system root classloader. - * @param baseClassLoader The spark classloader that is used to load shared classes. Must not know - * about Hive classes. + * @param rootClassLoader The system root classloader. Must not know about Hive classes. + * @param baseClassLoader The spark classloader that is used to load shared classes. */ private[hive] class IsolatedClientLoader( val version: HiveVersion, @@ -110,8 +109,8 @@ private[hive] class IsolatedClientLoader( val barrierPrefixes: Seq[String] = Seq.empty) extends Logging { - // Check to make sure that the base classloader does not know about Hive. - assert(Try(baseClassLoader.loadClass("org.apache.hive.HiveConf")).isFailure) + // Check to make sure that the root classloader does not know about Hive. + assert(Try(rootClassLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf")).isFailure) /** All jars used by the hive specific classloader. */ protected def allJars = execJars.toArray @@ -145,6 +144,7 @@ private[hive] class IsolatedClientLoader( def doLoadClass(name: String, resolve: Boolean): Class[_] = { val classFileName = name.replaceAll("\\.", "/") + ".class" if (isBarrierClass(name) && isolationOn) { + // For barrier classes, we construct a new copy of the class. val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName)) logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}") defineClass(name, bytes, 0, bytes.length) @@ -152,6 +152,7 @@ private[hive] class IsolatedClientLoader( logDebug(s"hive class: $name - ${getResource(classToPath(name))}") super.loadClass(name, resolve) } else { + // For shared classes, we delegate to baseClassLoader. logDebug(s"shared class: $name") baseClassLoader.loadClass(name) } @@ -167,7 +168,7 @@ private[hive] class IsolatedClientLoader( classLoader .loadClass(classOf[ClientWrapper].getName) .getConstructors.head - .newInstance(version, config) + .newInstance(version, config, classLoader) .asInstanceOf[ClientInterface] } catch { case e: InvocationTargetException => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala new file mode 100644 index 0000000000000..7963abf3b9c92 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import org.apache.spark._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.util.{ResetSystemProperties, Utils} +import org.scalatest.Matchers +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + +/** + * This suite tests spark-submit with applications using HiveContext. + */ +class HiveSparkSubmitSuite + extends SparkFunSuite + with Matchers + with ResetSystemProperties + with Timeouts { + + def beforeAll() { + System.setProperty("spark.testing", "true") + } + + test("SPARK-8368: includes jars passed in through --jars") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) + val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() + val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() + val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") + val args = Seq( + "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), + "--name", "SparkSubmitClassLoaderTest", + "--master", "local-cluster[2,1,512]", + "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + + test("SPARK-8020: set sql conf in spark conf") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SparkSQLConfTest.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,512]", + unusedJar.toString) + runSparkSubmit(args) + } + + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. + // This is copied from org.apache.spark.deploy.SparkSubmitSuite + private def runSparkSubmit(args: Seq[String]): Unit = { + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val process = Utils.executeCommand( + Seq("./bin/spark-submit") ++ args, + new File(sparkHome), + Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + try { + val exitCode = failAfter(120 seconds) { process.waitFor() } + if (exitCode != 0) { + fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") + } + } finally { + // Ensure we still kill the process in case it timed out + process.destroy() + } + } +} + +// This object is used for testing SPARK-8368: https://issues.apache.org/jira/browse/SPARK-8368. +// We test if we can load user jars in both driver and executors when HiveContext is used. +object SparkSubmitClassLoaderTest extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + val conf = new SparkConf() + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + val df = hiveContext.createDataFrame((1 to 100).map(i => (i, i))).toDF("i", "j") + // First, we load classes at driver side. + try { + Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) + Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + } catch { + case t: Throwable => + throw new Exception("Could not load user class from jar:\n", t) + } + // Second, we load classes at the executor side. + val result = df.mapPartitions { x => + var exception: String = null + try { + Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) + Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + } catch { + case t: Throwable => + exception = t + "\n" + t.getStackTraceString + exception = exception.replaceAll("\n", "\n\t") + } + Option(exception).toSeq.iterator + }.collect() + if (result.nonEmpty) { + throw new Exception("Could not load user class from jar:\n" + result(0)) + } + + // Load a Hive UDF from the jar. + hiveContext.sql( + """ + |CREATE TEMPORARY FUNCTION example_max + |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' + """.stripMargin) + val source = + hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") + source.registerTempTable("sourceTable") + // Load a Hive SerDe from the jar. + hiveContext.sql( + """ + |CREATE TABLE t1(key int, val string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe' + """.stripMargin) + // Actually use the loaded UDF and SerDe. + hiveContext.sql( + "INSERT INTO TABLE t1 SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + val count = hiveContext.table("t1").orderBy("key", "val").count() + if (count != 10) { + throw new Exception(s"table t1 should have 10 rows instead of $count rows") + } + } +} + +// This object is used for testing SPARK-8020: https://issues.apache.org/jira/browse/SPARK-8020. +// We test if we can correctly set spark sql configurations when HiveContext is used. +object SparkSQLConfTest extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + // We override the SparkConf to add spark.sql.hive.metastore.version and + // spark.sql.hive.metastore.jars to the beginning of the conf entry array. + // So, if metadataHive get initialized after we set spark.sql.hive.metastore.version but + // before spark.sql.hive.metastore.jars get set, we will see the following exception: + // Exception in thread "main" java.lang.IllegalArgumentException: Builtin jars can only + // be used when hive execution version == hive metastore version. + // Execution: 0.13.1 != Metastore: 0.12. Specify a vaild path to the correct hive jars + // using $HIVE_METASTORE_JARS or change spark.sql.hive.metastore.version to 0.13.1. + val conf = new SparkConf() { + override def getAll: Array[(String, String)] = { + def isMetastoreSetting(conf: String): Boolean = { + conf == "spark.sql.hive.metastore.version" || conf == "spark.sql.hive.metastore.jars" + } + // If there is any metastore settings, remove them. + val filteredSettings = super.getAll.filterNot(e => isMetastoreSetting(e._1)) + + // Always add these two metastore settings at the beginning. + ("spark.sql.hive.metastore.version" -> "0.12") +: + ("spark.sql.hive.metastore.jars" -> "maven") +: + filteredSettings + } + + // For this simple test, we do not really clone this object. + override def clone: SparkConf = this + } + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + // Run a simple command to make sure all lazy vals in hiveContext get instantiated. + hiveContext.tables().collect() + } +}