[SPARK-18817] [SPARKR] [SQL] Set default warehouse dir to tempdir #16290

Closed
wants to merge 9 commits
8 changes: 7 additions & 1 deletion R/pkg/R/sparkR.R
@@ -363,6 +363,7 @@ sparkR.session <- function(
   ...) {
 
   sparkConfigMap <- convertNamedListToEnv(sparkConfig)
+
   namedParams <- list(...)
   if (length(namedParams) > 0) {
     paramMap <- convertNamedListToEnv(namedParams)
@@ -400,11 +401,16 @@ sparkR.session <- function(
                              sparkConfigMap)
   } else {
     jsc <- get(".sparkRjsc", envir = .sparkREnv)
+    # NOTE(shivaram): Pass in a tempdir that is optionally used if the user has not
+    # overridden this. See SPARK-18817 for more details.
+    warehouseTmpDir <- file.path(tempdir(), "spark-warehouse")
+
     sparkSession <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
                                 "getOrCreateSparkSession",
                                 jsc,
                                 sparkConfigMap,
-                                enableHiveSupport)
+                                enableHiveSupport,
+                                warehouseTmpDir)
     assign(".sparkRsession", sparkSession, envir = .sparkREnv)
   }
   sparkSession
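As a hedged sketch of the resulting behavior (not part of the patch; a fresh R session is assumed, since sparkR.session() otherwise reuses the existing one, and the override path below is made up):

library(SparkR)

# Case 1: spark.sql.warehouse.dir is set nowhere, so the session falls back to
# the file.path(tempdir(), "spark-warehouse") default passed down from R, and
# no spark-warehouse directory is created in the current working directory.
sparkR.session()

# Case 2: an explicit entry in sparkConfig lands in sparkConfigMap, which the
# Scala side checks before applying the tempdir default, so this hypothetical
# path wins.
sparkR.session(sparkConfig = list(spark.sql.warehouse.dir = "/data/warehouse"))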
14 changes: 14 additions & 0 deletions R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2890,6 +2890,20 @@ test_that("Collect on DataFrame when NAs exists at the top of a timestamp column
   expect_equal(class(ldf3$col3), c("POSIXct", "POSIXt"))
 })
 
+test_that("Default warehouse dir should be set to tempdir", {
+  setHiveContext(sc)
+
+  # Create a temporary database and a table in it
+  sql("CREATE DATABASE db1")
+  sql("USE db1")
+  sql("CREATE TABLE boxes (width INT, length INT, height INT)")
+  # spark-warehouse should be written only to tempdir(), not to the current working directory
+  expect_true(file.exists(file.path(tempdir(), "spark-warehouse", "db1.db", "boxes")))
+  sql("DROP TABLE boxes")
+  sql("DROP DATABASE db1")
+  unsetHiveContext()
+})
+
 unlink(parquetPath)
 unlink(orcPath)
 unlink(jsonPath)
sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.execution.command.ShowTablesCommand
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH
 import org.apache.spark.sql.types._
 
 private[sql] object SQLUtils extends Logging {
@@ -46,7 +47,17 @@ private[sql] object SQLUtils extends Logging {
   def getOrCreateSparkSession(
       jsc: JavaSparkContext,
       sparkConfigMap: JMap[Object, Object],
-      enableHiveSupport: Boolean): SparkSession = {
+      enableHiveSupport: Boolean,
+      warehouseDir: String): SparkSession = {
+
+    // Check if the SparkContext conf or sparkConfigMap contains spark.sql.warehouse.dir.
+    // If not, set it to the warehouseDir chosen by the R process.
+    // NOTE: We need to do this before creating the SparkSession.
+    val sqlWarehouseKey = WAREHOUSE_PATH.key
+    if (!jsc.sc.conf.contains(sqlWarehouseKey) && !sparkConfigMap.containsKey(sqlWarehouseKey)) {
+      jsc.sc.conf.set(sqlWarehouseKey, warehouseDir)
+    }
+
     val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport
         && jsc.sc.conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") {
       SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate()
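The precedence implemented above can be checked from SparkR itself; a minimal sketch, assuming a session with no user-supplied warehouse setting (a value already on the SparkContext conf, e.g. from spark-submit --conf spark.sql.warehouse.dir=..., would win over both sparkConfigMap and the tempdir default):

library(SparkR)
sparkR.session()

# SET with a key returns the effective value as a one-row SparkDataFrame; with
# no override it should point inside tempdir() rather than ./spark-warehouse.
collect(sql("SET spark.sql.warehouse.dir"))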
sql/core/src/test/scala/org/apache/spark/sql/api/r/SQLUtilsSuite.scala
@@ -17,13 +17,23 @@
 
 package org.apache.spark.sql.api.r
 
-import org.apache.spark.sql.test.SharedSQLContext
+import java.util.HashMap
 
-class SQLUtilsSuite extends SharedSQLContext {
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.test.SharedSQLContext
 
-  import testImplicits._
+class SQLUtilsSuite extends SparkFunSuite {
 
   test("dfToCols should collect and transpose a data frame") {
+    val sparkSession = SparkSession.builder()
+      .master("local")
+      .config("spark.ui.enabled", value = false)
+      .getOrCreate()
+
+    import sparkSession.implicits._
+
     val df = Seq(
       (1, 2, 3),
       (4, 5, 6)
@@ -33,6 +43,19 @@ class SQLUtilsSuite extends SharedSQLContext {
       Array(2, 5),
       Array(3, 6)
     ))
+    sparkSession.stop()
   }
 
+  test("warehouse path is set correctly by R constructor") {
+    SparkSession.clearDefaultSession()
+    val conf = new SparkConf().setAppName("test").setMaster("local")
+    val sparkContext2 = new SparkContext(conf)
+    val jsc = new JavaSparkContext(sparkContext2)
+    val warehouseDir = "/tmp/test-warehouse-dir"
+    val session = SQLUtils.getOrCreateSparkSession(
+      jsc, new HashMap[Object, Object], false, warehouseDir)
+    assert(session.sessionState.conf.warehousePath == warehouseDir)
+    session.stop()
+    SparkSession.clearDefaultSession()
+  }
 }