diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 39a16e917c4a5..f3ca4f06cd372 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -20,6 +20,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{UserDefinedType, _} import org.apache.spark.unsafe.Platform @@ -33,6 +34,15 @@ import org.apache.spark.unsafe.Platform class InterpretedUnsafeProjection(expressions: Array[Expression]) extends UnsafeProjection { import InterpretedUnsafeProjection._ + private[this] val subExprEliminationEnabled = SQLConf.get.subexpressionEliminationEnabled + private[this] lazy val runtime = + new SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries) + private[this] val exprs = if (subExprEliminationEnabled) { + runtime.proxyExpressions(expressions) + } else { + expressions.toSeq + } + /** Number of (top level) fields in the resulting row. */ private[this] val numFields = expressions.length @@ -63,17 +73,21 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe } override def initialize(partitionIndex: Int): Unit = { - expressions.foreach(_.foreach { + exprs.foreach(_.foreach { case n: Nondeterministic => n.initialize(partitionIndex) case _ => }) } override def apply(row: InternalRow): UnsafeRow = { + if (subExprEliminationEnabled) { + runtime.setInput(row) + } + // Put the expression results in the intermediate row. var i = 0 while (i < numFields) { - values(i) = expressions(i).eval(row) + values(i) = exprs(i).eval(row) i += 1 } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala new file mode 100644 index 0000000000000..3189d81289903 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import java.util.IdentityHashMap + +import scala.collection.JavaConverters._ + +import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} +import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.DataType + +/** + * This class helps subexpression elimination for interpreted evaluation + * such as `InterpretedUnsafeProjection`. It maintains an evaluation cache. + * This class wraps `ExpressionProxy` around given expressions. The `ExpressionProxy` + * intercepts expression evaluation and loads from the cache first. + */ +class SubExprEvaluationRuntime(cacheMaxEntries: Int) { + // The id assigned to `ExpressionProxy`. `SubExprEvaluationRuntime` will use assigned ids of + // `ExpressionProxy` to decide the equality when loading from cache. `SubExprEvaluationRuntime` + // won't be use by multi-threads so we don't need to consider concurrency here. + private var proxyExpressionCurrentId = 0 + + private[sql] val cache: LoadingCache[ExpressionProxy, ResultProxy] = CacheBuilder.newBuilder() + .maximumSize(cacheMaxEntries) + .build( + new CacheLoader[ExpressionProxy, ResultProxy]() { + override def load(expr: ExpressionProxy): ResultProxy = { + ResultProxy(expr.proxyEval(currentInput)) + } + }) + + private var currentInput: InternalRow = null + + def getEval(proxy: ExpressionProxy): Any = try { + cache.get(proxy).result + } catch { + // Cache.get() may wrap the original exception. See the following URL + // http://google.github.io/guava/releases/14.0/api/docs/com/google/common/cache/ + // Cache.html#get(K,%20java.util.concurrent.Callable) + case e @ (_: UncheckedExecutionException | _: ExecutionError) => + throw e.getCause + } + + /** + * Sets given input row as current row for evaluating expressions. This cleans up the cache + * too as new input comes. + */ + def setInput(input: InternalRow = null): Unit = { + currentInput = input + cache.invalidateAll() + } + + /** + * Recursively replaces expression with its proxy expression in `proxyMap`. + */ + private def replaceWithProxy( + expr: Expression, + proxyMap: IdentityHashMap[Expression, ExpressionProxy]): Expression = { + if (proxyMap.containsKey(expr)) { + proxyMap.get(expr) + } else { + expr.mapChildren(replaceWithProxy(_, proxyMap)) + } + } + + /** + * Finds subexpressions and wraps them with `ExpressionProxy`. + */ + def proxyExpressions(expressions: Seq[Expression]): Seq[Expression] = { + val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions + + expressions.foreach(equivalentExpressions.addExprTree(_)) + + val proxyMap = new IdentityHashMap[Expression, ExpressionProxy] + + val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) + commonExprs.foreach { e => + val expr = e.head + val proxy = ExpressionProxy(expr, proxyExpressionCurrentId, this) + proxyExpressionCurrentId += 1 + + proxyMap.putAll(e.map(_ -> proxy).toMap.asJava) + } + + // Only adding proxy if we find subexpressions. + if (!proxyMap.isEmpty) { + expressions.map(replaceWithProxy(_, proxyMap)) + } else { + expressions + } + } +} + +/** + * A proxy for an catalyst `Expression`. Given a runtime object `SubExprEvaluationRuntime`, + * when this is asked to evaluate, it will load from the evaluation cache in the runtime first. + */ +case class ExpressionProxy( + child: Expression, + id: Int, + runtime: SubExprEvaluationRuntime) extends Expression { + + final override def dataType: DataType = child.dataType + final override def nullable: Boolean = child.nullable + final override def children: Seq[Expression] = child :: Nil + + // `ExpressionProxy` is for interpreted expression evaluation only. So cannot `doGenCode`. + final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + throw new UnsupportedOperationException(s"Cannot generate code for expression: $this") + + def proxyEval(input: InternalRow = null): Any = child.eval(input) + + override def eval(input: InternalRow = null): Any = runtime.getEval(this) + + override def equals(obj: Any): Boolean = obj match { + case other: ExpressionProxy => this.id == other.id + case _ => false + } + + override def hashCode(): Int = this.id.hashCode() +} + +/** + * A simple wrapper for holding `Any` in the cache of `SubExprEvaluationRuntime`. + */ +case class ResultProxy(result: Any) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f2e309013a5b6..25e1c6ab517fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -539,6 +539,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SUBEXPRESSION_ELIMINATION_CACHE_MAX_ENTRIES = + buildConf("spark.sql.subexpressionElimination.cache.maxEntries") + .internal() + .doc("The maximum entries of the cache used for interpreted subexpression elimination.") + .version("3.1.0") + .intConf + .checkValue(_ >= 0, "The maximum must not be negative") + .createWithDefault(100) + val CASE_SENSITIVE = buildConf("spark.sql.caseSensitive") .internal() .doc("Whether the query analyzer should be case sensitive or not. " + @@ -3233,6 +3242,9 @@ class SQLConf extends Serializable with Logging { def subexpressionEliminationEnabled: Boolean = getConf(SUBEXPRESSION_ELIMINATION_ENABLED) + def subexpressionEliminationCacheMaxEntries: Int = + getConf(SUBEXPRESSION_ELIMINATION_CACHE_MAX_ENTRIES) + def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD) def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala new file mode 100644 index 0000000000000..badcd4fc3fdad --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.IntegerType + +class SubExprEvaluationRuntimeSuite extends SparkFunSuite { + + test("Evaluate ExpressionProxy should create cached result") { + val runtime = new SubExprEvaluationRuntime(1) + val proxy = ExpressionProxy(Literal(1), 0, runtime) + assert(runtime.cache.size() == 0) + proxy.eval() + assert(runtime.cache.size() == 1) + assert(runtime.cache.get(proxy) == ResultProxy(1)) + } + + test("SubExprEvaluationRuntime cannot exceed configured max entries") { + val runtime = new SubExprEvaluationRuntime(2) + assert(runtime.cache.size() == 0) + + val proxy1 = ExpressionProxy(Literal(1), 0, runtime) + proxy1.eval() + assert(runtime.cache.size() == 1) + assert(runtime.cache.get(proxy1) == ResultProxy(1)) + + val proxy2 = ExpressionProxy(Literal(2), 1, runtime) + proxy2.eval() + assert(runtime.cache.size() == 2) + assert(runtime.cache.get(proxy2) == ResultProxy(2)) + + val proxy3 = ExpressionProxy(Literal(3), 2, runtime) + proxy3.eval() + assert(runtime.cache.size() == 2) + assert(runtime.cache.get(proxy3) == ResultProxy(3)) + } + + test("setInput should empty cached result") { + val runtime = new SubExprEvaluationRuntime(2) + val proxy1 = ExpressionProxy(Literal(1), 0, runtime) + assert(runtime.cache.size() == 0) + proxy1.eval() + assert(runtime.cache.size() == 1) + assert(runtime.cache.get(proxy1) == ResultProxy(1)) + + val proxy2 = ExpressionProxy(Literal(2), 1, runtime) + proxy2.eval() + assert(runtime.cache.size() == 2) + assert(runtime.cache.get(proxy2) == ResultProxy(2)) + + runtime.setInput() + assert(runtime.cache.size() == 0) + } + + test("Wrap ExpressionProxy on subexpressions") { + val runtime = new SubExprEvaluationRuntime(1) + + val one = Literal(1) + val two = Literal(2) + val mul = Multiply(one, two) + val mul2 = Multiply(mul, mul) + val sqrt = Sqrt(mul2) + val sum = Add(mul2, sqrt) + + // ( (one * two) * (one * two) ) + sqrt( (one * two) * (one * two) ) + val proxyExpressions = runtime.proxyExpressions(Seq(sum)) + val proxys = proxyExpressions.flatMap(_.collect { + case p: ExpressionProxy => p + }) + // ( (one * two) * (one * two) ) + assert(proxys.size == 2) + val expected = ExpressionProxy(mul2, 0, runtime) + assert(proxys.forall(_ == expected)) + } + + test("ExpressionProxy won't be on non deterministic") { + val runtime = new SubExprEvaluationRuntime(1) + + val sum = Add(Rand(0), Rand(0)) + val proxys = runtime.proxyExpressions(Seq(sum, sum)).flatMap(_.collect { + case p: ExpressionProxy => p + }) + assert(proxys.isEmpty) + } +} diff --git a/sql/core/benchmarks/SubExprEliminationBenchmark-jdk11-results.txt b/sql/core/benchmarks/SubExprEliminationBenchmark-jdk11-results.txt index 49dc7adccbf3c..3d2b2e5c8edba 100644 --- a/sql/core/benchmarks/SubExprEliminationBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/SubExprEliminationBenchmark-jdk11-results.txt @@ -7,9 +7,9 @@ OpenJDK 64-Bit Server VM 11.0.9+11 on Mac OS X 10.15.6 Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz from_json as subExpr: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -subexpressionElimination off, codegen on 26809 27731 898 0.0 268094225.4 1.0X -subexpressionElimination off, codegen off 25117 26612 1357 0.0 251166638.4 1.1X -subexpressionElimination on, codegen on 2582 2906 282 0.0 25819408.7 10.4X -subexpressionElimination on, codegen off 25635 26131 804 0.0 256346873.1 1.0X +subexpressionElimination off, codegen on 25932 26908 916 0.0 259320042.3 1.0X +subexpressionElimination off, codegen off 26085 26159 65 0.0 260848905.0 1.0X +subexpressionElimination on, codegen on 2860 2939 72 0.0 28603312.9 9.1X +subexpressionElimination on, codegen off 2517 2617 93 0.0 25165157.7 10.3X diff --git a/sql/core/benchmarks/SubExprEliminationBenchmark-results.txt b/sql/core/benchmarks/SubExprEliminationBenchmark-results.txt index 3f131726bc53d..ca2a9c6497500 100644 --- a/sql/core/benchmarks/SubExprEliminationBenchmark-results.txt +++ b/sql/core/benchmarks/SubExprEliminationBenchmark-results.txt @@ -7,9 +7,9 @@ OpenJDK 64-Bit Server VM 1.8.0_265-b01 on Mac OS X 10.15.6 Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz from_json as subExpr: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -subexpressionElimination off, codegen on 24841 25365 803 0.0 248412787.5 1.0X -subexpressionElimination off, codegen off 25344 26205 941 0.0 253442656.5 1.0X -subexpressionElimination on, codegen on 2883 3019 119 0.0 28833086.8 8.6X -subexpressionElimination on, codegen off 24707 25688 903 0.0 247068775.9 1.0X +subexpressionElimination off, codegen on 26503 27622 1937 0.0 265033362.4 1.0X +subexpressionElimination off, codegen off 24920 25376 430 0.0 249196978.2 1.1X +subexpressionElimination on, codegen on 2421 2466 39 0.0 24213606.1 10.9X +subexpressionElimination on, codegen off 2360 2435 87 0.0 23604320.7 11.2X