
[NSE-947] Add a whole stage fallback strategy (#948)
* Initial commit

* Move to add the strategy in ColumnarOverrideRules

* Revert "Initial commit"

This reverts commit c4210ccfe39fac9b144117ec403bd9b00d8b5703.

* Consider stage boundary

* Differentiate the handling for spark raw BatchScanExec & arrow datasource BatchScanExec

* Handle InMemoryTableScanExec

* Use spark's transition code

* Support fallback for LocalWindowExec

* Enable this feature when AQE is on

* Add a config

* Check whether AQE is supported

* Set default value to -1
PHILO-HE authored Sep 8, 2022
1 parent 3ee659c commit 890df14
Showing 8 changed files with 285 additions and 60 deletions.
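The commit message above introduces a new config key whose -1 default leaves the feature disabled, and ties the strategy to AQE being on. A minimal usage sketch, assuming a plain Spark application (the session setup, the local master, and the threshold value 2 are illustrative; the key name and its -1 default come from the GazellePluginConfig change below):

import org.apache.spark.sql.SparkSession

// Illustrative setup: turn on AQE and set the whole-stage fallback threshold added by this commit.
val spark = SparkSession.builder()
  .appName("whole-stage-fallback-demo")
  .master("local[*]") // local master chosen only for the sketch
  // The fallback strategy is applied only when adaptive query execution is enabled.
  .config("spark.sql.adaptive.enabled", "true")
  // Fall back the whole stage once 2 or more operators revert to row-based execution;
  // -1 (the default) keeps the feature disabled.
  .config("spark.oap.sql.columnar.wholeStage.fallback.threshold", "2")
  .getOrCreate()
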
@@ -61,6 +61,12 @@ class GazellePluginConfig(conf: SQLConf) extends Logging {
conf.getConfString(
ENABLE_HASH_AGG_FOR_STRING_TYPE_KEY, "true").toBoolean && enableCpu

// If a stage has >= threshold fallbacks after columnar operators are applied, let the
// whole stage fall back to row-based execution. Setting -1 disables this feature.
val WHOLE_STAGE_FALLBACK_THRESHOLD_KEY = "spark.oap.sql.columnar.wholeStage.fallback.threshold"
val WHOLE_STAGE_FALLBACK_THRESHOLD: Int =
conf.getConfString(WHOLE_STAGE_FALLBACK_THRESHOLD_KEY, "-1").toInt

// enable or disable columnar project and filter
val enableColumnarProjFilter: Boolean =
conf.getConfString("spark.oap.sql.columnar.projfilter", "true").toBoolean && enableCpu
@@ -27,6 +27,7 @@ import com.intel.oap.sql.shims.SparkShimLoader

import org.apache.spark.{MapOutputStatistics, SparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.LocalWindowExec
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.BuildLeft
@@ -548,30 +549,26 @@ case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule wit
def preOverrides = ColumnarPreOverrides(session)
def postOverrides = ColumnarPostOverrides()

def columnarWholeStageEnabled = conf.getBoolean("spark.oap.sql.columnar.wholestagecodegen", defaultValue = true) && !codegendisable
def columnarWholeStageEnabled = conf.getBoolean(
"spark.oap.sql.columnar.wholestagecodegen", defaultValue = true) && !codegendisable
def collapseOverrides = ColumnarCollapseCodegenStages(columnarWholeStageEnabled)
def enableArrowColumnarToRow: Boolean =
conf.getBoolean("spark.oap.sql.columnar.columnartorow", defaultValue = true)
def wholeStageFallbackThreshold: Int =
conf.getInt("spark.oap.sql.columnar.wholeStage.fallback.threshold", defaultValue = -1)

var isSupportAdaptive: Boolean = true

private def supportAdaptive(plan: SparkPlan): Boolean = {
// TODO migrate dynamic-partition-pruning onto adaptive execution.
// Only QueryStage will have Exchange as Leaf Plan
val isLeafPlanExchange = plan match {
case e: Exchange => true
case other => false
}
isLeafPlanExchange || (SQLConf.get.adaptiveExecutionEnabled && (sanityCheck(plan) &&
!plan.logicalLink.exists(_.isStreaming) &&
!plan.expressions.exists(_.find(_.isInstanceOf[DynamicPruningSubquery]).isDefined) &&
plan.children.forall(supportAdaptive)))
}

private def sanityCheck(plan: SparkPlan): Boolean =
plan.logicalLink.isDefined
var originalPlan: SparkPlan = _
var fallbacks = 0

override def preColumnarTransitions: Rule[SparkPlan] = plan => {
if (columnarEnabled) {
isSupportAdaptive = supportAdaptive(plan)
// According to Spark's Columnar.scala, the plans are processed one by one.
// By recording the original plan here, we can easily make the whole stage
// fall back in #postColumnarTransitions.
originalPlan = plan
isSupportAdaptive = SparkShimLoader.getSparkShims.supportAdaptiveWithExchangeConsidered(plan)
val rule = preOverrides
rule.setAdaptiveSupport(isSupportAdaptive)
rule(rowGuardOverrides(plan))
@@ -580,18 +577,82 @@
}
}

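// Recursively count the ColumnarToRowExec nodes in the transformed plan, i.e. the points
// where execution transitions from columnar back to row-based processing.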
def checkColumnarToRow(plan: SparkPlan): Unit = {
plan match {
case _: ColumnarToRowExec =>
fallbacks = fallbacks + 1
case _ =>
}
plan.children.foreach(child => checkColumnarToRow(child))
}

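// Decide whether the whole stage should fall back: a threshold of -1 disables the feature;
// otherwise the number of ColumnarToRowExec transitions is compared against the threshold.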
def fallbackWholeStage(plan: SparkPlan): Boolean = {
if (wholeStageFallbackThreshold == -1) {
return false
}
fallbacks = 0
checkColumnarToRow(plan)
if (fallbacks >= wholeStageFallbackThreshold) {
true
} else {
false
}
}

/**
* Ported from ApplyColumnarRulesAndInsertTransitions of Spark.
* Inserts a transition to columnar formatted data.
*/
private def insertRowToColumnar(plan: SparkPlan): SparkPlan = {
if (!plan.supportsColumnar) {
// The tree feels kind of backwards
// Columnar Processing will start here, so transition from row to columnar
RowToColumnarExec(insertTransitions(plan, outputsColumnar = false))
} else if (!plan.isInstanceOf[RowToColumnarTransition]) {
plan.withNewChildren(plan.children.map(insertRowToColumnar))
} else {
plan
}
}

/**
* Ported from ApplyColumnarRulesAndInsertTransitions of Spark.
* Inserts RowToColumnarExecs and ColumnarToRowExecs where needed.
*/
private def insertTransitions(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = {
if (outputsColumnar) {
insertRowToColumnar(plan)
} else if (plan.supportsColumnar) {
// `outputsColumnar` is false but the plan outputs columnar format, so add a
// to-row transition here.
ColumnarToRowExec(insertRowToColumnar(plan))
} else if (!plan.isInstanceOf[ColumnarToRowTransition]) {
plan.withNewChildren(plan.children.map(insertTransitions(_, outputsColumnar = false)))
} else {
plan
}
}

override def postColumnarTransitions: Rule[SparkPlan] = plan => {
if (columnarEnabled) {
val rule = postOverrides
rule.setAdaptiveSupport(isSupportAdaptive)
val tmpPlan = rule(plan)
val ret = collapseOverrides(tmpPlan)
if (codegendisable)
{
logDebug("postColumnarTransitions: resetting spark.oap.sql.columnar.codegendisableforsmallshuffles To false")
session.sqlContext.setConf("spark.oap.sql.columnar.codegendisableforsmallshuffles", "false")
if (isSupportAdaptive && fallbackWholeStage(plan)) {
// BatchScan with ArrowScan initialized can still connect
// to ColumnarToRow for transition.
insertTransitions(originalPlan, false)
} else {
val rule = postOverrides
rule.setAdaptiveSupport(isSupportAdaptive)
val tmpPlan = rule(plan)
val ret = collapseOverrides(tmpPlan)
if (codegendisable)
{
logDebug("postColumnarTransitions:" +
" resetting spark.oap.sql.columnar.codegendisableforsmallshuffles To false")
session.sqlContext.setConf(
"spark.oap.sql.columnar.codegendisableforsmallshuffles", "false")
}
ret
}
ret
} else {
plan
}
@@ -20,6 +20,8 @@ package com.intel.oap.extension
import com.intel.oap.GazellePluginConfig
import com.intel.oap.GazelleSparkExtensionsInjector
import com.intel.oap.execution.LocalPhysicalWindow

import org.apache.spark.sql.LocalWindowExec
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SparkSessionExtensions, Strategy, execution}
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
@@ -75,37 +77,6 @@ object JoinSelectionOverrides extends Strategy with JoinSelectionHelper with SQL
}
}

case class LocalWindowExec(
windowExpression: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: SparkPlan)
extends WindowExecBase {

override def output: Seq[Attribute] =
child.output ++ windowExpression.map(_.toAttribute)

override def requiredChildDistribution: Seq[Distribution] = {
super.requiredChildDistribution
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] =
Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)

override def outputOrdering: Seq[SortOrder] = child.outputOrdering

override def outputPartitioning: Partitioning = child.outputPartitioning

protected override def doExecute(): RDD[InternalRow] = {
// todo implement this to fall back
throw new UnsupportedOperationException()
}

protected def withNewChildInternal(newChild: SparkPlan):
LocalWindowExec =
copy(child = newChild)
}

object LocalWindowApply extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case LocalPhysicalWindow(
@@ -19,10 +19,10 @@ package com.intel.oap.extension.columnar

import com.intel.oap.GazellePluginConfig
import com.intel.oap.execution._
import com.intel.oap.extension.LocalWindowExec
import com.intel.oap.sql.shims.SparkShimLoader

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.LocalWindowExec
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.FullOuter
@@ -0,0 +1,154 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, JoinedRow, NamedExpression, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning}
import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan}
import org.apache.spark.sql.execution.window.WindowExecBase

case class LocalWindowExec(
windowExpression: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: SparkPlan)
extends WindowExecBase {

override def output: Seq[Attribute] =
child.output ++ windowExpression.map(_.toAttribute)

override def requiredChildDistribution: Seq[Distribution] = {
super.requiredChildDistribution
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] =
Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)

override def outputOrdering: Seq[SortOrder] = child.outputOrdering

override def outputPartitioning: Partitioning = child.outputPartitioning

// This function is copied from Spark's WindowExec.
protected override def doExecute(): RDD[InternalRow] = {
// Unwrap the window expressions and window frame factories from the map.
val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)
val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
val spillThreshold = conf.windowExecBufferSpillThreshold

// Start processing.
child.execute().mapPartitions { stream =>
new Iterator[InternalRow] {

// Get all relevant projections.
val result = createResultProjection(expressions)
val grouping = UnsafeProjection.create(partitionSpec, child.output)

// Manage the stream and the grouping.
var nextRow: UnsafeRow = null
var nextGroup: UnsafeRow = null
var nextRowAvailable: Boolean = false
private[this] def fetchNextRow(): Unit = {
nextRowAvailable = stream.hasNext
if (nextRowAvailable) {
nextRow = stream.next().asInstanceOf[UnsafeRow]
nextGroup = grouping(nextRow)
} else {
nextRow = null
nextGroup = null
}
}
fetchNextRow()

// Manage the current partition.
val buffer: ExternalAppendOnlyUnsafeRowArray =
new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)

var bufferIterator: Iterator[UnsafeRow] = _

val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType))
val frames = factories.map(_(windowFunctionResult))
val numFrames = frames.length
private[this] def fetchNextPartition(): Unit = {
// Collect all the rows in the current partition.
// Before we start to fetch new input rows, make a copy of nextGroup.
val currentGroup = nextGroup.copy()

// clear last partition
buffer.clear()

while (nextRowAvailable && nextGroup == currentGroup) {
buffer.add(nextRow)
fetchNextRow()
}

// Setup the frames.
var i = 0
while (i < numFrames) {
frames(i).prepare(buffer)
i += 1
}

// Setup iteration
rowIndex = 0
bufferIterator = buffer.generateIterator()
}

// Iteration
var rowIndex = 0

override final def hasNext: Boolean =
(bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable

val join = new JoinedRow
override final def next(): InternalRow = {
// Load the next partition if we need to.
if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) {
fetchNextPartition()
}

if (bufferIterator.hasNext) {
val current = bufferIterator.next()

// Get the results for the window frames.
var i = 0
while (i < numFrames) {
frames(i).write(rowIndex, current)
i += 1
}

// 'Merge' the input row with the window function result
join(current, windowFunctionResult)
rowIndex += 1

// Return the projection.
result(join)
} else {
throw new NoSuchElementException
}
}
}
}
}

protected def withNewChildInternal(newChild: SparkPlan):
LocalWindowExec =
copy(child = newChild)
}