diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala index a28a7d26b386..7b17bf61ff16 100644 --- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala @@ -18,6 +18,7 @@ package org.apache.gluten import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.SQLConf import com.google.common.collect.ImmutableList @@ -34,9 +35,13 @@ case class GlutenNumaBindingInfo( totalCoreRange: Array[String] = null, numCoresPerExecutor: Int = -1) {} -class GlutenConfig(conf: SQLConf) extends Logging { +class GlutenConfig(sessionOpt: Option[SparkSession] = None) extends Logging { import GlutenConfig._ + def this(spark: SparkSession) = this(Some(spark)) + + def conf: SQLConf = sessionOpt.map(_.sessionState.conf).getOrElse(SQLConf.get) + def enableAnsiMode: Boolean = conf.ansiEnabled def enableGluten: Boolean = conf.getConf(GLUTEN_ENABLED) @@ -648,9 +653,7 @@ object GlutenConfig { var ins: GlutenConfig = _ - def getConf: GlutenConfig = { - new GlutenConfig(SQLConf.get) - } + def getConf: GlutenConfig = new GlutenConfig() @deprecated def getTempFile: String = synchronized {