Merge branch 'apache:master' into char_varchar
jovanm-db authored Nov 19, 2024
2 parents 9918010 + f1b68d8 commit 6dfdad4
Showing 3 changed files with 96 additions and 6 deletions.
@@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.SupportsWrite
 import org.apache.spark.sql.connector.write.V1Write
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.sources.InsertableRelation

 /**
@@ -58,14 +59,27 @@ case class OverwriteByExpressionExecV1(
 sealed trait V1FallbackWriters extends LeafV2CommandExec with SupportsV1Write {
   override def output: Seq[Attribute] = Nil

+  override val metrics: Map[String, SQLMetric] =
+    write.supportedCustomMetrics().map { customMetric =>
+      customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric)
+    }.toMap
+
   def table: SupportsWrite
   def refreshCache: () => Unit
   def write: V1Write

   override def run(): Seq[InternalRow] = {
-    val writtenRows = writeWithV1(write.toInsertableRelation)
+    writeWithV1(write.toInsertableRelation)
     refreshCache()
-    writtenRows
+
+    write.reportDriverMetrics().foreach { customTaskMetric =>
+      metrics.get(customTaskMetric.name()).foreach(_.set(customTaskMetric.value()))
+    }
+
+    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+
+    Nil
   }
 }

@@ -75,8 +89,7 @@ sealed trait V1FallbackWriters extends LeafV2CommandExec with SupportsV1Write {
 trait SupportsV1Write extends SparkPlan {
   def plan: LogicalPlan

-  protected def writeWithV1(relation: InsertableRelation): Seq[InternalRow] = {
+  protected def writeWithV1(relation: InsertableRelation): Unit = {
     relation.insert(Dataset.ofRows(session, plan), overwrite = false)
-    Nil
   }
 }
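The new run() plumbing only surfaces what the connector declares and reports: supportedCustomMetrics() names the metrics so a SQLMetric can be registered for each, and reportDriverMetrics() is read back on the driver once the V1 write has finished and then published via SQLMetrics.postDriverMetricUpdates. A minimal connector-side sketch of that contract follows; it mirrors the InMemoryTableWithV1Fallback change later in this commit, and the RowCountingV1Write / RowCountMetric names and the numOutputRows metric name are illustrative only, not part of the commit.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric}
import org.apache.spark.sql.connector.write.V1Write
import org.apache.spark.sql.sources.InsertableRelation

// Hypothetical connector-side V1Write, used only to illustrate the hooks consumed above.
class RowCountingV1Write extends V1Write {
  case class RowCountMetric(name: String, description: String) extends CustomSumMetric
  case class RowCountTaskMetric(name: String, value: Long) extends CustomTaskMetric

  @volatile private var reported: Array[CustomTaskMetric] = Array.empty

  // Declared up front; V1FallbackWriters registers one SQLMetric per declared name.
  override def supportedCustomMetrics(): Array[CustomMetric] =
    Array(RowCountMetric("numOutputRows", "Number of output rows"))

  // Read on the driver after writeWithV1 completes, then posted to the SQL UI.
  override def reportDriverMetrics(): Array[CustomTaskMetric] = reported

  override def toInsertableRelation: InsertableRelation = {
    (data: DataFrame, overwrite: Boolean) => {
      val rows = data.collect()
      reported = Array(RowCountTaskMetric("numOutputRows", rows.length))
      // A real connector would hand `rows` to its V1 sink here.
    }
  }
}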
25 changes: 24 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -27,8 +27,9 @@ import org.scalatest.Assertions
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.ArrayImplicits._

@@ -447,6 +448,28 @@ object QueryTest extends Assertions {
       case None =>
     }
   }
+
+  def withPhysicalPlansCaptured(spark: SparkSession, thunk: => Unit): Seq[SparkPlan] = {
+    var capturedPlans = Seq.empty[SparkPlan]
+
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        capturedPlans = capturedPlans :+ qe.executedPlan
+      }
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+    }
+
+    spark.sparkContext.listenerBus.waitUntilEmpty(15000)
+    spark.listenerManager.register(listener)
+    try {
+      thunk
+      spark.sparkContext.listenerBus.waitUntilEmpty(15000)
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+
+    capturedPlans
+  }
 }

 class QueryTestSuite extends QueryTest with test.SharedSparkSession {
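The new helper registers a QueryExecutionListener around an arbitrary action and returns every physical plan that finished executing while the action ran. A short usage sketch, assuming the spark-sql test classes are on the classpath and using the built-in noop sink as a stand-in write (both assumptions, not part of this commit); the commit's actual use is in the V1 fallback suite below.

import org.apache.spark.sql.QueryTest.withPhysicalPlansCaptured
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()

// Capture the physical plans executed by the write action.
val plans = withPhysicalPlansCaptured(spark, {
  spark.range(10).write.format("noop").mode("overwrite").save()
})

// Inspect what actually ran, e.g. print the node name of each captured plan.
plans.foreach(p => println(p.nodeName))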
@@ -24,16 +24,20 @@ import org.scalatest.BeforeAndAfter

 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.QueryTest.withPhysicalPlansCaptured
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.catalyst.util.quoteIdentifier
 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, SupportsRead, SupportsWrite, Table, TableCapability}
 import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric}
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
 import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1Write, WriteBuilder}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.execution.datasources.v2.{AppendDataExecV1, OverwriteByExpressionExecV1}
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.internal.SQLConf.{OPTIMIZER_MAX_ITERATIONS, V2_SESSION_CATALOG_IMPLEMENTATION}
 import org.apache.spark.sql.sources._
@@ -198,6 +202,43 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before
       SparkSession.setDefaultSession(spark)
     }
   }
+
+  test("SPARK-50315: metrics for V1 fallback writers") {
+    SparkSession.clearActiveSession()
+    SparkSession.clearDefaultSession()
+    try {
+      val session = SparkSession.builder()
+        .master("local[1]")
+        .config(V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[V1FallbackTableCatalog].getName)
+        .getOrCreate()
+
+      def captureWrite(sparkSession: SparkSession)(thunk: => Unit): SparkPlan = {
+        val physicalPlans = withPhysicalPlansCaptured(sparkSession, thunk)
+        val v1FallbackWritePlans = physicalPlans.filter {
+          case _: AppendDataExecV1 | _: OverwriteByExpressionExecV1 => true
+          case _ => false
+        }

+        assert(v1FallbackWritePlans.size === 1)
+        v1FallbackWritePlans.head
+      }
+
+      val appendPlan = captureWrite(session) {
+        val df = session.createDataFrame(Seq((1, "x")))
+        df.write.mode("append").option("name", "t1").format(v2Format).saveAsTable("test")
+      }
+      assert(appendPlan.metrics("numOutputRows").value === 1)
+
+      val overwritePlan = captureWrite(session) {
+        val df2 = session.createDataFrame(Seq((2, "y")))
+        df2.writeTo("test").overwrite(lit(true))
+      }
+      assert(overwritePlan.metrics("numOutputRows").value === 1)
+    } finally {
+      SparkSession.setActiveSession(spark)
+      SparkSession.setDefaultSession(spark)
+    }
+  }
 }

 class V1WriteFallbackSessionCatalogSuite
@@ -376,10 +417,23 @@ class InMemoryTableWithV1Fallback(
       }

       override def build(): V1Write = new V1Write {
+        case class SupportedV1WriteMetric(name: String, description: String) extends CustomSumMetric
+
+        override def supportedCustomMetrics(): Array[CustomMetric] =
+          Array(SupportedV1WriteMetric("numOutputRows", "Number of output rows"))
+
+        private var writeMetrics = Array.empty[CustomTaskMetric]
+
+        override def reportDriverMetrics(): Array[CustomTaskMetric] = writeMetrics
+
         override def toInsertableRelation: InsertableRelation = {
           (data: DataFrame, overwrite: Boolean) => {
             assert(!overwrite, "V1 write fallbacks cannot be called with overwrite=true")
             val rows = data.collect()
+
+            case class V1WriteTaskMetric(name: String, value: Long) extends CustomTaskMetric
+            writeMetrics = Array(V1WriteTaskMetric("numOutputRows", rows.length))
+
             rows.groupBy(getPartitionValues).foreach { case (partition, elements) =>
               if (dataMap.contains(partition) && mode == "append") {
                 dataMap.put(partition, dataMap(partition) ++ elements)
