Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
[NSE-542] Add rule to propagate local window for rank + filter pattern (
Browse files Browse the repository at this point in the history
#545)

Closes #542
  • Loading branch information
zhztheplayer authored Nov 5, 2021
1 parent 1cad018 commit e710234
Show file tree
Hide file tree
Showing 9 changed files with 395 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,18 @@ package com.intel.oap
import java.util
import java.util.Collections
import java.util.Objects

import scala.language.implicitConversions

import com.intel.oap.GazellePlugin.GAZELLE_SESSION_EXTENSION_NAME
import com.intel.oap.GazellePlugin.SPARK_SESSION_EXTS_KEY
import com.intel.oap.extension.StrategyOverrides

import com.intel.oap.extension.{OptimizerOverrides, StrategyOverrides}
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.api.plugin.DriverPlugin
import org.apache.spark.api.plugin.ExecutorPlugin
import org.apache.spark.api.plugin.PluginContext
import org.apache.spark.api.plugin.SparkPlugin
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}

class GazellePlugin extends SparkPlugin {
override def driverPlugin(): DriverPlugin = {
Expand Down Expand Up @@ -87,22 +84,32 @@ private[oap] trait GazelleSparkExtensionsInjector {
}

private[oap] object GazellePlugin {

// Classes to be loaded by name through this object's class loader; the list is currently empty.
val LOCAL_OVERRIDDEN_CLASSES: Seq[Class[_]] = Seq()

// Runs as part of this object's initialization, after LOCAL_OVERRIDDEN_CLASSES has been assigned.
initialLocalOverriddenClasses()

// To enable GazellePlugin in production, set "spark.plugins=com.intel.oap.GazellePlugin"
val SPARK_SQL_PLUGINS_KEY: String = "spark.plugins"
// Canonical class name of GazellePlugin; requireNonNull fails fast if no canonical name exists.
val GAZELLE_PLUGIN_NAME: String = Objects.requireNonNull(classOf[GazellePlugin]
.getCanonicalName)
// Configuration key taken from StaticSQLConf.SPARK_SESSION_EXTENSIONS.
val SPARK_SESSION_EXTS_KEY: String = StaticSQLConf.SPARK_SESSION_EXTENSIONS.key
// Canonical class name of GazelleSessionExtensions; requireNonNull fails fast if it is null.
val GAZELLE_SESSION_EXTENSION_NAME: String = Objects.requireNonNull(
classOf[GazelleSessionExtensions].getCanonicalName)

/**
* Specify all injectors that Gazelle is using in the following list.
*/
val DEFAULT_INJECTORS: List[GazelleSparkExtensionsInjector] = List(
ColumnarOverrides,
OptimizerOverrides,
StrategyOverrides
)

/**
 * Loads, by name, every class listed in LOCAL_OVERRIDDEN_CLASSES using the class loader
 * of the GazellePlugin object. Does nothing when the list is empty.
 */
def initialLocalOverriddenClasses(): Unit = {
  for (overridden <- LOCAL_OVERRIDDEN_CLASSES) {
    // The loader is looked up per element, so nothing is touched for an empty list.
    GazellePlugin.getClass.getClassLoader.loadClass(overridden.getName)
  }
}

/**
 * Implicit view that wraps a SparkConf in a new SparkConfImplicits instance.
 *
 * @param conf the configuration to wrap
 * @return a SparkConfImplicits constructed from `conf`
 */
implicit def sparkConfImplicit(conf: SparkConf): SparkConfImplicits =
  new SparkConfImplicits(conf)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class GazellePluginConfig(conf: SQLConf) extends Logging {
val lines =
try source.mkString
finally source.close()
return true
//TODO(): check CPU flags to enable/disable AVX512
if (lines.contains("GenuineIntel")) {
return true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,19 @@
package com.intel.oap.execution

import java.util.concurrent.TimeUnit

import com.google.flatbuffers.FlatBufferBuilder
import com.intel.oap.GazellePluginConfig
import com.intel.oap.expression.{CodeGeneration, ConverterUtils}
import com.intel.oap.vectorized.{ArrowWritableColumnVector, CloseableColumnBatchIterator, ExpressionEvaluator}
import org.apache.arrow.gandiva.expression.TreeBuilder
import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, Descending, Expression, Literal, MakeDecimal, NamedExpression, PredicateHelper, Rank, SortOrder, UnscaledValue, WindowExpression, WindowFunction, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.{SortExec, SparkPlan}
Expand All @@ -44,17 +42,19 @@ import org.apache.spark.sql.types.{DataType, DateType, DecimalType, DoubleType,
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.ExecutorManager

import scala.collection.JavaConverters._
import scala.collection.immutable.Stream.Empty
import scala.collection.mutable.ListBuffer
import scala.util.Random

import org.apache.spark.sql.execution.datasources.v2.arrow.SparkSchemaUtils

import util.control.Breaks._

case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
isLocal: Boolean,
child: SparkPlan) extends WindowExecBase {

override def supportsColumnar: Boolean = true
Expand All @@ -64,6 +64,10 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
buildCheck()

override def requiredChildDistribution: Seq[Distribution] = {
if (isLocal) {
// localized window doesn't require distribution
return Seq.fill(children.size)(UnspecifiedDistribution)
}
if (partitionSpec.isEmpty) {
// Only show warning when the number of bytes is larger than 100 MiB?
logWarning("No Partition Defined for Window operation! Moving all data to a single "
Expand Down Expand Up @@ -421,7 +425,7 @@ object ColumnarWindowExec extends Logging {
}

override def apply(plan: SparkPlan): SparkPlan = plan transformUp {
case p @ ColumnarWindowExec(windowExpression, partitionSpec, orderSpec, child) =>
case p @ ColumnarWindowExec(windowExpression, partitionSpec, orderSpec, isLocalized, child) =>
val windows = ListBuffer[NamedExpression]()
val inProjectExpressions = ListBuffer[NamedExpression]()
val outProjectExpressions = windowExpression.map(e => e.asInstanceOf[Alias])
Expand All @@ -431,7 +435,8 @@ object ColumnarWindowExec extends Logging {
}
val inputProject = ColumnarConditionProjectExec(null,
child.output ++ inProjectExpressions, child)
val window = new ColumnarWindowExec(windows, partitionSpec, orderSpec, inputProject)
val window = new ColumnarWindowExec(windows, partitionSpec, orderSpec, isLocalized,
inputProject)
val outputProject = ColumnarConditionProjectExec(null,
child.output ++ outProjectExpressions, window)
outputProject
Expand All @@ -440,14 +445,14 @@ object ColumnarWindowExec extends Logging {

/**
 * Plan rule that removes a sort operator sitting directly beneath a ColumnarWindowExec.
 *
 * When the window's child is a SortExec or a ColumnarSortExec, the window is rebuilt with
 * the sort's own children, so the sort node disappears from the plan.
 */
object RemoveSort extends Rule[SparkPlan] with PredicateHelper {
  override def apply(plan: SparkPlan): SparkPlan = plan transform {
    // ColumnarWindowExec has five fields (windowExpression, partitionSpec, orderSpec,
    // isLocal, child), so the extractor pattern needs four wildcards before the child.
    // The former four-argument pattern no longer matches the case class arity.
    case p1 @ ColumnarWindowExec(_, _, _, _, p2 @ (_: SortExec | _: ColumnarSortExec)) =>
      p1.withNewChildren(p2.children)
  }
}

/**
 * Plan rule that removes a CoalesceBatchesExec sitting directly beneath a ColumnarWindowExec.
 *
 * The window is rebuilt with the coalesce node's own children, so the coalesce node
 * disappears from the plan.
 */
object RemoveCoalesceBatches extends Rule[SparkPlan] with PredicateHelper {
  override def apply(plan: SparkPlan): SparkPlan = plan transform {
    // ColumnarWindowExec has five fields (windowExpression, partitionSpec, orderSpec,
    // isLocal, child), so the extractor pattern needs four wildcards before the child.
    // The former four-argument pattern no longer matches the case class arity.
    case p1 @ ColumnarWindowExec(_, _, _, _, p2: CoalesceBatchesExec) =>
      p1.withNewChildren(p2.children)
  }
}
Expand Down Expand Up @@ -516,11 +521,13 @@ object ColumnarWindowExec extends Logging {
/**
 * Builds a ColumnarWindowExec from the given pieces and passes it through
 * ColumnarWindowExec.optimize.
 *
 * @param windowExpression window expressions handed to the new operator
 * @param partitionSpec partitioning expressions of the window
 * @param orderSpec ordering of the window
 * @param isLocalized forwarded as the operator's isLocal flag
 * @param child child plan of the new operator
 * @return whatever plan ColumnarWindowExec.optimize returns for the new operator
 */
def createWithOptimizations(windowExpression: Seq[NamedExpression],
    partitionSpec: Seq[Expression],
    orderSpec: Seq[SortOrder],
    isLocalized: Boolean,
    child: SparkPlan): SparkPlan =
  ColumnarWindowExec.optimize(
    new ColumnarWindowExec(windowExpression, partitionSpec, orderSpec, isLocalized, child))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.oap.execution

import com.intel.oap.extension.LocalRankWindow
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression, SortOrder, WindowFunctionType}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window}

/**
 * Extractor for logical Window nodes that carry a local-window column.
 *
 * A Window matches only when at least one of its window expressions is an Alias whose name
 * is accepted by LocalRankWindow.isLocalWindowColumnName. Any other input yields None.
 */
object LocalPhysicalWindow {
  // (windowFunctionType, windowExpression, partitionSpec, orderSpec, child)
  private type ReturnType =
    (WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan)

  /**
   * @param a candidate plan node
   * @return the window function type shared by all window expressions together with the
   *         window's expressions, partition spec, order spec and child; None when `a` is
   *         not a Window or has no local-window alias
   * @throws IllegalArgumentException if the Window has no window expressions, or if its
   *                                  expressions do not all share one window function type
   */
  def unapply(a: Any): Option[ReturnType] = a match {
    case window @ Window(windowExpressions, partitionSpec, orderSpec, child) =>
      // An empty expression list is treated as a bug rather than a non-match.
      if (windowExpressions.isEmpty) {
        throw new IllegalArgumentException(s"Window expression is empty in $window")
      }

      val hasLocalWindowColumn = windowExpressions.exists {
        case alias: Alias => LocalRankWindow.isLocalWindowColumnName(alias.name)
        case _ => false
      }

      if (hasLocalWindowColumn) {
        // Resolve every expression's function type first, then require all of them to
        // agree with the first one; mixed types are treated as a bug.
        val functionTypes = windowExpressions.map(WindowFunctionType.functionType)
        val sharedType = functionTypes.head
        if (functionTypes.exists(sharedType != _)) {
          throw new IllegalArgumentException(
            s"Found different window function type in $windowExpressions")
        }
        Some((sharedType, windowExpressions, partitionSpec, orderSpec, child))
      } else {
        None
      }

    case _ => None
  }
}

// Empty object: no members are defined here.
// NOTE(review): looks like a placeholder — confirm whether it is still needed.
object Patterns {

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
package com.intel.oap

import com.intel.oap.execution._
import com.intel.oap.extension.LocalWindowExec
import com.intel.oap.extension.columnar.ColumnarGuardRule
import com.intel.oap.extension.columnar.RowGuard
import com.intel.oap.sql.execution.RowToArrowColumnarExec

import org.apache.spark.internal.config._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
Expand Down Expand Up @@ -235,12 +235,26 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] {
plan.windowExpression,
plan.partitionSpec,
plan.orderSpec,
isLocalized = false,
replaceWithColumnarPlan(plan.child))
} catch {
case _: Throwable =>
logInfo("Columnar Window: Falling back to regular Window...")
plan
}
case plan: LocalWindowExec =>
try {
ColumnarWindowExec.createWithOptimizations(
plan.windowExpression,
plan.partitionSpec,
plan.orderSpec,
isLocalized = true,
replaceWithColumnarPlan(plan.child))
} catch {
case _: Throwable =>
logInfo("Localized Columnar Window: Falling back to regular Window...")
plan
}
case p =>
val children = plan.children.map(replaceWithColumnarPlan)
logDebug(s"Columnar Processing for ${p.getClass} is currently not supported.")
Expand Down
Loading

0 comments on commit e710234

Please sign in to comment.