Add support for AtomicCreateTableAsSelect with Delta Lake [databricks] #9425

Merged: 3 commits, Oct 17, 2023
@@ -18,6 +18,7 @@ package com.databricks.sql.transaction.tahoe.rapids

import com.databricks.sql.transaction.tahoe.{DeltaLog, OptimisticTransaction}
import com.nvidia.spark.rapids.RapidsConf
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.util.Clock
@@ -70,4 +71,12 @@ object GpuDeltaLog {
val deltaLog = DeltaLog.forTable(spark, dataPath, options)
new GpuDeltaLog(deltaLog, rapidsConf)
}

def forTable(
spark: SparkSession,
tableLocation: Path,
rapidsConf: RapidsConf): GpuDeltaLog = {
val deltaLog = DeltaLog.forTable(spark, tableLocation)
new GpuDeltaLog(deltaLog, rapidsConf)
}
}
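A minimal usage sketch of the new Path-based overload (illustrative only; the location value is made up, and spark / rapidsConf are assumed to be in scope):

import org.apache.hadoop.fs.Path

// Hypothetical call site: resolve a GpuDeltaLog directly from a table location
// rather than from a path string. The Path value below is an assumption.
val tableLocation = new Path("/tmp/delta/events")
val gpuDeltaLog = GpuDeltaLog.forTable(spark, tableLocation, rapidsConf)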
@@ -16,22 +16,31 @@

package com.nvidia.spark.rapids.delta

import scala.collection.JavaConverters.mapAsScalaMapConverter

import com.databricks.sql.managedcatalog.UnityCatalogV2Proxy
import com.databricks.sql.transaction.tahoe.{DeltaLog, DeltaParquetFileFormat}
import com.databricks.sql.transaction.tahoe.catalog.DeltaCatalog
import com.databricks.sql.transaction.tahoe.commands.{DeleteCommand, DeleteCommandEdge, MergeIntoCommand, MergeIntoCommandEdge, UpdateCommand, UpdateCommandEdge}
import com.databricks.sql.transaction.tahoe.sources.DeltaDataSource
import com.databricks.sql.transaction.tahoe.rapids.GpuDeltaCatalog
import com.databricks.sql.transaction.tahoe.sources.{DeltaDataSource, DeltaSourceUtils}
import com.nvidia.spark.rapids._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.StagingTableCatalog
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.{FileFormat, SaveIntoDataSourceCommand}
import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec
import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.ExternalSource
import org.apache.spark.sql.sources.CreatableRelationProvider

/**
* Implements the DeltaProvider interface for Databricks Delta Lake.
* Common implementation of the DeltaProvider interface for all Databricks versions.
*/
object DeltaProviderImpl extends DeltaProviderImplBase {
object DatabricksDeltaProvider extends DeltaProviderImplBase {
override def getCreatableRelationRules: Map[Class[_ <: CreatableRelationProvider],
CreatableRelationProviderRule[_ <: CreatableRelationProvider]] = {
Seq(
@@ -92,6 +101,43 @@ object DeltaProviderImpl extends DeltaProviderImplBase {
val cpuFormat = format.asInstanceOf[DeltaParquetFileFormat]
GpuDeltaParquetFileFormat.convertToGpu(cpuFormat)
}

override def isSupportedCatalog(catalogClass: Class[_ <: StagingTableCatalog]): Boolean = {
catalogClass == classOf[DeltaCatalog] || catalogClass == classOf[UnityCatalogV2Proxy]
Review comment — @gerashegalov (Collaborator), Oct 16, 2023:

question: are we guaranteed in this comparison that both catalogClass and classOf[DeltaCatalog] are loaded via the same ClassLoader instance? We might get a false negative otherwise.

Reply — PR author (Member):

I believe so. The technique essentially matches the one used for almost all overrides: GpuOverrides builds a large hashmap with classes as keys and does a lookup on that map based on what is found in the CPU plan. So if we had an issue with this code, we'd have a similar issue with how all overrides work in the plugin.
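For illustration only (not the plugin's actual code), a minimal Scala sketch of the class-keyed lookup pattern the reply describes; the class names come from this file's imports and the rule values are placeholders:

// Illustrative rule table: keys are CPU classes, and lookup uses the runtime class
// of the catalog found in the plan. Equality holds as long as both Class objects
// come from the same ClassLoader, the normal case for Spark and plugin classes.
val catalogRules: Map[Class[_], String] = Map(
  classOf[DeltaCatalog] -> "GPU CTAS supported",
  classOf[UnityCatalogV2Proxy] -> "GPU CTAS supported"
)

def isSupported(catalog: StagingTableCatalog): Boolean =
  catalogRules.contains(catalog.getClass)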

}

override def tagForGpu(
cpuExec: AtomicCreateTableAsSelectExec,
meta: AtomicCreateTableAsSelectExecMeta): Unit = {
require(isSupportedCatalog(cpuExec.catalog.getClass))
if (!meta.conf.isDeltaWriteEnabled) {
meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " +
s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
}
val properties = cpuExec.properties
val provider = properties.getOrElse("provider",
cpuExec.conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME))
if (!DeltaSourceUtils.isDeltaDataSourceName(provider)) {
meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider")
}
RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None,
cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session)
}

override def convertToGpu(
cpuExec: AtomicCreateTableAsSelectExec,
meta: AtomicCreateTableAsSelectExecMeta): GpuExec = {
GpuAtomicCreateTableAsSelectExec(
cpuExec.output,
new GpuDeltaCatalog(cpuExec.catalog, meta.conf),
cpuExec.ident,
cpuExec.partitioning,
cpuExec.plan,
meta.childPlans.head.convertIfNeeded(),
cpuExec.tableSpec,
cpuExec.writeOptions,
cpuExec.ifNotExists)
}
}
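For context, a hedged sketch of the kind of query this new path targets (table and source names are assumptions): a Delta CTAS that plans as AtomicCreateTableAsSelectExec when the session catalog supports staged creation, which the new tagForGpu/convertToGpu pair can place on the GPU once RapidsConf.ENABLE_DELTA_WRITE is set to true.

// Illustrative only: a Delta CTAS exercising the staged (atomic) create path.
spark.sql(
  """CREATE TABLE delta_ctas_demo
    |USING delta
    |AS SELECT * FROM source_table""".stripMargin)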

class DeltaCreatableRelationProviderMeta(
@@ -115,8 +161,8 @@ class DeltaCreatableRelationProviderMeta(
val path = saveCmd.options.get("path")
if (path.isDefined) {
val deltaLog = DeltaLog.forTable(SparkSession.active, path.get, saveCmd.options)
RapidsDeltaUtils.tagForDeltaWrite(this, saveCmd.query.schema, deltaLog, saveCmd.options,
SparkSession.active)
RapidsDeltaUtils.tagForDeltaWrite(this, saveCmd.query.schema, Some(deltaLog),
saveCmd.options, SparkSession.active)
} else {
willNotWorkOnGpu("no path specified for Delta Lake table")
}
@@ -131,5 +177,5 @@
*/
class DeltaProbeImpl extends DeltaProbe {
// Delta Lake is built-in for Databricks instances, so no probing is necessary.
override def getDeltaProvider: DeltaProvider = DeltaProviderImpl
override def getDeltaProvider: DeltaProvider = DatabricksDeltaProvider
}
@@ -37,7 +37,7 @@ class DeleteCommandMeta(
s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
}
DeleteCommandMetaShim.tagForGpu(this)
RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, deleteCmd.deltaLog,
RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, Some(deleteCmd.deltaLog),
Map.empty, SparkSession.active)
}

@@ -62,7 +62,7 @@ class DeleteCommandEdgeMeta(
s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
}
DeleteCommandMetaShim.tagForGpu(this)
RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, deleteCmd.deltaLog,
RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, Some(deleteCmd.deltaLog),
Map.empty, SparkSession.active)
}

@@ -38,7 +38,8 @@ class MergeIntoCommandMeta(
MergeIntoCommandMetaShim.tagForGpu(this, mergeCmd)
val targetSchema = mergeCmd.migratedSchema.getOrElse(mergeCmd.target.schema)
val deltaLog = mergeCmd.targetFileIndex.deltaLog
RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, deltaLog, Map.empty, SparkSession.active)
RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, Some(deltaLog), Map.empty,
SparkSession.active)
}

override def convertToGpu(): RunnableCommand =
@@ -60,7 +61,8 @@ class MergeIntoCommandEdgeMeta(
MergeIntoCommandMetaShim.tagForGpu(this, mergeCmd)
val targetSchema = mergeCmd.migratedSchema.getOrElse(mergeCmd.target.schema)
val deltaLog = mergeCmd.targetFileIndex.deltaLog
RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, deltaLog, Map.empty, SparkSession.active)
RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, Some(deltaLog), Map.empty,
SparkSession.active)
}

override def convertToGpu(): RunnableCommand =
@@ -29,12 +29,13 @@ object RapidsDeltaUtils {
def tagForDeltaWrite(
meta: RapidsMeta[_, _, _],
schema: StructType,
deltaLog: DeltaLog,
deltaLog: Option[DeltaLog],
options: Map[String, String],
spark: SparkSession): Unit = {
FileFormatChecks.tag(meta, schema, DeltaFormatType, WriteFileOp)
val format = DeltaLogShim.fileFormat(deltaLog)
if (format.getClass == classOf[DeltaParquetFileFormat]) {
val format = deltaLog.map(log => DeltaLogShim.fileFormat(log).getClass)
.getOrElse(classOf[DeltaParquetFileFormat])
if (format == classOf[DeltaParquetFileFormat]) {
GpuParquetFileFormat.tagGpuSupport(meta, spark, options, schema)
} else {
meta.willNotWorkOnGpu(s"file format $format is not supported")
@@ -45,7 +46,7 @@
private def checkIncompatibleConfs(
meta: RapidsMeta[_, _, _],
schema: StructType,
deltaLog: DeltaLog,
deltaLog: Option[DeltaLog],
sqlConf: SQLConf,
options: Map[String, String]): Unit = {
def getSQLConf(key: String): Option[String] = {
@@ -65,19 +66,21 @@
orderableTypeSig.isSupportedByPlugin(t)
}
if (unorderableTypes.nonEmpty) {
val metadata = DeltaLogShim.getMetadata(deltaLog)
val hasPartitioning = metadata.partitionColumns.nonEmpty ||
val metadata = deltaLog.map(log => DeltaLogShim.getMetadata(log))
val hasPartitioning = metadata.exists(_.partitionColumns.nonEmpty) ||
options.get(DataSourceUtils.PARTITIONING_COLUMNS_KEY).exists(_.nonEmpty)
if (!hasPartitioning) {
val optimizeWriteEnabled = {
val deltaOptions = new DeltaOptions(options, sqlConf)
deltaOptions.optimizeWrite.orElse {
getSQLConf("spark.databricks.delta.optimizeWrite.enabled").map(_.toBoolean).orElse {
DeltaConfigs.AUTO_OPTIMIZE.fromMetaData(metadata).orElse {
metadata.configuration.get("delta.autoOptimize.optimizeWrite").orElse {
getSQLConf(
"spark.databricks.delta.properties.defaults.autoOptimize.optimizeWrite")
}.map(_.toBoolean)
metadata.flatMap { m =>
DeltaConfigs.AUTO_OPTIMIZE.fromMetaData(m).orElse {
m.configuration.get("delta.autoOptimize.optimizeWrite").orElse {
getSQLConf(
"spark.databricks.delta.properties.defaults.autoOptimize.optimizeWrite")
}.map(_.toBoolean)
}
}
}
}.getOrElse(false)
@@ -35,7 +35,7 @@ class UpdateCommandMeta(
s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
}
RapidsDeltaUtils.tagForDeltaWrite(this, updateCmd.target.schema,
updateCmd.tahoeFileIndex.deltaLog, Map.empty, updateCmd.tahoeFileIndex.spark)
Some(updateCmd.tahoeFileIndex.deltaLog), Map.empty, updateCmd.tahoeFileIndex.spark)
}

override def convertToGpu(): RunnableCommand = {
@@ -62,7 +62,7 @@ class UpdateCommandEdgeMeta(
s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
}
RapidsDeltaUtils.tagForDeltaWrite(this, updateCmd.target.schema,
updateCmd.tahoeFileIndex.deltaLog, Map.empty, updateCmd.tahoeFileIndex.spark)
Some(updateCmd.tahoeFileIndex.deltaLog), Map.empty, updateCmd.tahoeFileIndex.spark)
}

override def convertToGpu(): RunnableCommand = {
@@ -16,15 +16,20 @@

package com.nvidia.spark.rapids.delta

import scala.collection.JavaConverters.mapAsScalaMapConverter
import scala.util.Try

import com.nvidia.spark.rapids._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.StagingTableCatalog
import org.apache.spark.sql.delta.{DeltaLog, DeltaParquetFileFormat}
import org.apache.spark.sql.delta.catalog.DeltaCatalog
import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim
import org.apache.spark.sql.delta.sources.DeltaDataSource
import org.apache.spark.sql.delta.sources.{DeltaDataSource, DeltaSourceUtils}
import org.apache.spark.sql.execution.datasources.{FileFormat, SaveIntoDataSourceCommand}
import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.ExternalSource
import org.apache.spark.sql.rapids.execution.UnshimmedTrampolineUtil
import org.apache.spark.sql.sources.CreatableRelationProvider
@@ -48,6 +53,28 @@ abstract class DeltaIOProvider extends DeltaProviderImplBase {
override def isSupportedFormat(format: Class[_ <: FileFormat]): Boolean = {
format == classOf[DeltaParquetFileFormat]
}

override def isSupportedCatalog(catalogClass: Class[_ <: StagingTableCatalog]): Boolean = {
catalogClass == classOf[DeltaCatalog]
}

override def tagForGpu(
cpuExec: AtomicCreateTableAsSelectExec,
meta: AtomicCreateTableAsSelectExecMeta): Unit = {
require(isSupportedCatalog(cpuExec.catalog.getClass))
if (!meta.conf.isDeltaWriteEnabled) {
meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " +
s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
}
val properties = cpuExec.properties
val provider = properties.getOrElse("provider",
cpuExec.conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME))
if (!DeltaSourceUtils.isDeltaDataSourceName(provider)) {
meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider")
}
RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None,
cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session)
}
}

class DeltaCreatableRelationProviderMeta(
@@ -71,8 +98,8 @@ class DeltaCreatableRelationProviderMeta(
val path = saveCmd.options.get("path")
if (path.isDefined) {
val deltaLog = DeltaLog.forTable(SparkSession.active, path.get, saveCmd.options)
RapidsDeltaUtils.tagForDeltaWrite(this, saveCmd.query.schema, deltaLog, saveCmd.options,
SparkSession.active)
RapidsDeltaUtils.tagForDeltaWrite(this, saveCmd.query.schema, Some(deltaLog),
saveCmd.options, SparkSession.active)
} else {
willNotWorkOnGpu("no path specified for Delta Lake table")
}
@@ -28,12 +28,13 @@ object RapidsDeltaUtils {
def tagForDeltaWrite(
meta: RapidsMeta[_, _, _],
schema: StructType,
deltaLog: DeltaLog,
deltaLog: Option[DeltaLog],
options: Map[String, String],
spark: SparkSession): Unit = {
FileFormatChecks.tag(meta, schema, DeltaFormatType, WriteFileOp)
val format = DeltaRuntimeShim.fileFormatFromLog(deltaLog)
if (format.getClass == classOf[DeltaParquetFileFormat]) {
val format = deltaLog.map(log => DeltaRuntimeShim.fileFormatFromLog(log).getClass)
.getOrElse(classOf[DeltaParquetFileFormat])
if (format == classOf[DeltaParquetFileFormat]) {
GpuParquetFileFormat.tagGpuSupport(meta, spark, options, schema)
} else {
meta.willNotWorkOnGpu(s"file format $format is not supported")
@@ -43,7 +44,7 @@

private def checkIncompatibleConfs(
meta: RapidsMeta[_, _, _],
deltaLog: DeltaLog,
deltaLog: Option[DeltaLog],
sqlConf: SQLConf,
options: Map[String, String]): Unit = {
def getSQLConf(key: String): Option[String] = {
@@ -58,11 +59,13 @@
val deltaOptions = new DeltaOptions(options, sqlConf)
deltaOptions.optimizeWrite.orElse {
getSQLConf("spark.databricks.delta.optimizeWrite.enabled").map(_.toBoolean).orElse {
val metadata = DeltaRuntimeShim.unsafeVolatileSnapshotFromLog(deltaLog).metadata
DeltaConfigs.AUTO_OPTIMIZE.fromMetaData(metadata).orElse {
metadata.configuration.get("delta.autoOptimize.optimizeWrite").orElse {
getSQLConf("spark.databricks.delta.properties.defaults.autoOptimize.optimizeWrite")
}.map(_.toBoolean)
deltaLog.flatMap { log =>
val metadata = DeltaRuntimeShim.unsafeVolatileSnapshotFromLog(log).metadata
DeltaConfigs.AUTO_OPTIMIZE.fromMetaData(metadata).orElse {
metadata.configuration.get("delta.autoOptimize.optimizeWrite").orElse {
getSQLConf("spark.databricks.delta.properties.defaults.autoOptimize.optimizeWrite")
}.map(_.toBoolean)
}
}
}
}.getOrElse(false)
@@ -73,9 +76,11 @@

val autoCompactEnabled =
getSQLConf("spark.databricks.delta.autoCompact.enabled").orElse {
val metadata = DeltaRuntimeShim.unsafeVolatileSnapshotFromLog(deltaLog).metadata
metadata.configuration.get("delta.autoOptimize.autoCompact").orElse {
getSQLConf("spark.databricks.delta.properties.defaults.autoOptimize.autoCompact")
deltaLog.flatMap { log =>
val metadata = DeltaRuntimeShim.unsafeVolatileSnapshotFromLog(log).metadata
metadata.configuration.get("delta.autoOptimize.autoCompact").orElse {
getSQLConf("spark.databricks.delta.properties.defaults.autoOptimize.autoCompact")
}
}
}.exists(_.toBoolean)
if (autoCompactEnabled) {
@@ -22,7 +22,9 @@ import com.nvidia.spark.rapids.{RapidsConf, ShimReflectionUtils, VersionUtils}
import com.nvidia.spark.rapids.delta.DeltaProvider

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.StagingTableCatalog
import org.apache.spark.sql.delta.{DeltaLog, DeltaUDF, Snapshot}
import org.apache.spark.sql.delta.catalog.DeltaCatalog
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.util.Clock
@@ -35,6 +37,8 @@ trait DeltaRuntimeShim {
def fileFormatFromLog(deltaLog: DeltaLog): FileFormat

def getTightBoundColumnOnFileInitDisabled(spark: SparkSession): Boolean

def getGpuDeltaCatalog(cpuCatalog: DeltaCatalog, rapidsConf: RapidsConf): StagingTableCatalog
}

object DeltaRuntimeShim {
@@ -81,4 +85,8 @@

def getTightBoundColumnOnFileInitDisabled(spark: SparkSession): Boolean =
shimInstance.getTightBoundColumnOnFileInitDisabled(spark)

def getGpuDeltaCatalog(cpuCatalog: DeltaCatalog, rapidsConf: RapidsConf): StagingTableCatalog = {
shimInstance.getGpuDeltaCatalog(cpuCatalog, rapidsConf)
}
}
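A small usage sketch of the new shim entry point (an assumption, not code from the PR); cpuCatalog and rapidsConf are presumed to be in scope:

// Hypothetical caller: obtain the GPU staging catalog for a CPU DeltaCatalog via
// the version-dispatched shim, mirroring the GpuDeltaCatalog wrapping shown in the
// Databricks provider earlier in this diff.
val gpuCatalog: StagingTableCatalog =
  DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, rapidsConf)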