From fbcfc56b5b50dae9727835026777771b01658473 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Mon, 28 Mar 2022 20:54:07 +0800 Subject: [PATCH] backport runtime filter Signed-off-by: Yuan Zhou --- .../apache/spark/util/sketch/BloomFilter.java | 253 + .../spark/util/sketch/BloomFilterImpl.java | 276 ++ .../expressions/BloomFilterMightContain.scala | 100 + .../aggregate/BloomFilterAggregate.scala | 196 + .../expressions/objects/objects.scala | 1950 ++++++++ .../expressions/regexpExpressions.scala | 909 ++++ .../optimizer/InjectRuntimeFilter.scala | 311 ++ .../sql/catalyst/trees/TreePatterns.scala | 137 + .../spark/sql/execution/SparkOptimizer.scala | 92 + .../apache/spark/sql/internal/SQLConf.scala | 4356 +++++++++++++++++ 10 files changed, 8580 insertions(+) create mode 100644 shims/spark321/src/main/java/org/apache/spark/util/sketch/BloomFilter.java create mode 100644 shims/spark321/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java create mode 100644 shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala create mode 100644 shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala create mode 100644 shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala create mode 100644 shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala create mode 100644 shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala create mode 100644 shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala create mode 100644 shims/spark321/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala create mode 100644 shims/spark321/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala diff --git a/shims/spark321/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/shims/spark321/src/main/java/org/apache/spark/util/sketch/BloomFilter.java new file mode 100644 index 000000000..2a6e270a9 --- /dev/null +++ b/shims/spark321/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * A Bloom filter is a space-efficient probabilistic data structure that offers an approximate + * containment test with one-sided error: if it claims that an item is contained in it, this + * might be in error, but if it claims that an item is not contained in it, then this is + * definitely true. Currently supported data types include: + * + * The false positive probability ({@code FPP}) of a Bloom filter is defined as the probability that + * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that has + * not actually been put in the {@code BloomFilter}. + * + * The implementation is largely based on the {@code BloomFilter} class from Guava. + */ +public abstract class BloomFilter { + + public enum Version { + /** + * {@code BloomFilter} binary format version 1. All values written in big-endian order: + * + */ + V1(1); + + private final int versionNumber; + + Version(int versionNumber) { + this.versionNumber = versionNumber; + } + + int getVersionNumber() { + return versionNumber; + } + } + + /** + * Returns the probability that {@linkplain #mightContain(Object)} erroneously return {@code true} + * for an object that has not actually been put in the {@code BloomFilter}. + * + * Ideally, this number should be close to the {@code fpp} parameter passed in + * {@linkplain #create(long, double)}, or smaller. If it is significantly higher, it is usually + * the case that too many items (more than expected) have been put in the {@code BloomFilter}, + * degenerating it. + */ + public abstract double expectedFpp(); + + /** + * Returns the number of bits in the underlying bit array. + */ + public abstract long bitSize(); + + /** + * Puts an item into this {@code BloomFilter}. Ensures that subsequent invocations of + * {@linkplain #mightContain(Object)} with the same item will always return {@code true}. + * + * @return true if the bloom filter's bits changed as a result of this operation. If the bits + * changed, this is definitely the first time {@code object} has been added to the + * filter. If the bits haven't changed, this might be the first time {@code object} + * has been added to the filter. Note that {@code put(t)} always returns the + * opposite result to what {@code mightContain(t)} would have returned at the time + * it is called. + */ + public abstract boolean put(Object item); + + /** + * A specialized variant of {@link #put(Object)} that only supports {@code String} items. + */ + public abstract boolean putString(String item); + + /** + * A specialized variant of {@link #put(Object)} that only supports {@code long} items. + */ + public abstract boolean putLong(long item); + + /** + * A specialized variant of {@link #put(Object)} that only supports byte array items. + */ + public abstract boolean putBinary(byte[] item); + + /** + * Determines whether a given bloom filter is compatible with this bloom filter. For two + * bloom filters to be compatible, they must have the same bit size. + * + * @param other The bloom filter to check for compatibility. + */ + public abstract boolean isCompatible(BloomFilter other); + + /** + * Combines this bloom filter with another bloom filter by performing a bitwise OR of the + * underlying data. The mutations happen to this instance. Callers must ensure the + * bloom filters are appropriately sized to avoid saturating them. + * + * @param other The bloom filter to combine this bloom filter with. It is not mutated. + * @throws IncompatibleMergeException if {@code isCompatible(other) == false} + */ + public abstract BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException; + + /** + * Combines this bloom filter with another bloom filter by performing a bitwise AND of the + * underlying data. The mutations happen to this instance. Callers must ensure the + * bloom filters are appropriately sized to avoid saturating them. + * + * @param other The bloom filter to combine this bloom filter with. It is not mutated. + * @throws IncompatibleMergeException if {@code isCompatible(other) == false} + */ + public abstract BloomFilter intersectInPlace(BloomFilter other) throws IncompatibleMergeException; + + /** + * Returns {@code true} if the element might have been put in this Bloom filter, + * {@code false} if this is definitely not the case. + */ + public abstract boolean mightContain(Object item); + + /** + * A specialized variant of {@link #mightContain(Object)} that only tests {@code String} items. + */ + public abstract boolean mightContainString(String item); + + /** + * A specialized variant of {@link #mightContain(Object)} that only tests {@code long} items. + */ + public abstract boolean mightContainLong(long item); + + /** + * A specialized variant of {@link #mightContain(Object)} that only tests byte array items. + */ + public abstract boolean mightContainBinary(byte[] item); + + /** + * Writes out this {@link BloomFilter} to an output stream in binary format. It is the caller's + * responsibility to close the stream. + */ + public abstract void writeTo(OutputStream out) throws IOException; + + /** + * @return the number of set bits in this {@link BloomFilter}. + */ + public long cardinality() { + throw new UnsupportedOperationException("Not implemented"); + } + + /** + * Reads in a {@link BloomFilter} from an input stream. It is the caller's responsibility to close + * the stream. + */ + public static BloomFilter readFrom(InputStream in) throws IOException { + return BloomFilterImpl.readFrom(in); + } + + /** + * Computes the optimal k (number of hashes per item inserted in Bloom filter), given the + * expected insertions and total number of bits in the Bloom filter. + * + * See http://en.wikipedia.org/wiki/File:Bloom_filter_fp_probability.svg for the formula. + * + * @param n expected insertions (must be positive) + * @param m total number of bits in Bloom filter (must be positive) + */ + private static int optimalNumOfHashFunctions(long n, long m) { + // (m / n) * log(2), but avoid truncation due to division! + return Math.max(1, (int) Math.round((double) m / n * Math.log(2))); + } + + /** + * Computes m (total bits of Bloom filter) which is expected to achieve, for the specified + * expected insertions, the required false positive probability. + * + * See http://en.wikipedia.org/wiki/Bloom_filter#Probability_of_false_positives for the formula. + * + * @param n expected insertions (must be positive) + * @param p false positive rate (must be 0 < p < 1) + */ + private static long optimalNumOfBits(long n, double p) { + return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2))); + } + + static final double DEFAULT_FPP = 0.03; + + /** + * Creates a {@link BloomFilter} with the expected number of insertions and a default expected + * false positive probability of 3%. + * + * Note that overflowing a {@code BloomFilter} with significantly more elements than specified, + * will result in its saturation, and a sharp deterioration of its false positive probability. + */ + public static BloomFilter create(long expectedNumItems) { + return create(expectedNumItems, DEFAULT_FPP); + } + + /** + * Creates a {@link BloomFilter} with the expected number of insertions and expected false + * positive probability. + * + * Note that overflowing a {@code BloomFilter} with significantly more elements than specified, + * will result in its saturation, and a sharp deterioration of its false positive probability. + */ + public static BloomFilter create(long expectedNumItems, double fpp) { + if (fpp <= 0D || fpp >= 1D) { + throw new IllegalArgumentException( + "False positive probability must be within range (0.0, 1.0)" + ); + } + + return create(expectedNumItems, optimalNumOfBits(expectedNumItems, fpp)); + } + + /** + * Creates a {@link BloomFilter} with given {@code expectedNumItems} and {@code numBits}, it will + * pick an optimal {@code numHashFunctions} which can minimize {@code fpp} for the bloom filter. + */ + public static BloomFilter create(long expectedNumItems, long numBits) { + if (expectedNumItems <= 0) { + throw new IllegalArgumentException("Expected insertions must be positive"); + } + + if (numBits <= 0) { + throw new IllegalArgumentException("Number of bits must be positive"); + } + + return new BloomFilterImpl(optimalNumOfHashFunctions(expectedNumItems, numBits), numBits); + } +} diff --git a/shims/spark321/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/shims/spark321/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java new file mode 100644 index 000000000..b053de8bd --- /dev/null +++ b/shims/spark321/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.sketch; + +import java.io.*; + +class BloomFilterImpl extends BloomFilter implements Serializable { + + private int numHashFunctions; + + private BitArray bits; + + BloomFilterImpl(int numHashFunctions, long numBits) { + this(new BitArray(numBits), numHashFunctions); + } + + private BloomFilterImpl(BitArray bits, int numHashFunctions) { + this.bits = bits; + this.numHashFunctions = numHashFunctions; + } + + private BloomFilterImpl() {} + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + + if (other == null || !(other instanceof BloomFilterImpl)) { + return false; + } + + BloomFilterImpl that = (BloomFilterImpl) other; + + return this.numHashFunctions == that.numHashFunctions && this.bits.equals(that.bits); + } + + @Override + public int hashCode() { + return bits.hashCode() * 31 + numHashFunctions; + } + + @Override + public double expectedFpp() { + return Math.pow((double) bits.cardinality() / bits.bitSize(), numHashFunctions); + } + + @Override + public long bitSize() { + return bits.bitSize(); + } + + @Override + public boolean put(Object item) { + if (item instanceof String) { + return putString((String) item); + } else if (item instanceof byte[]) { + return putBinary((byte[]) item); + } else { + return putLong(Utils.integralToLong(item)); + } + } + + @Override + public boolean putString(String item) { + return putBinary(Utils.getBytesFromUTF8String(item)); + } + + @Override + public boolean putBinary(byte[] item) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); + + long bitSize = bits.bitSize(); + boolean bitsChanged = false; + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + bitsChanged |= bits.set(combinedHash % bitSize); + } + return bitsChanged; + } + + @Override + public boolean mightContainString(String item) { + return mightContainBinary(Utils.getBytesFromUTF8String(item)); + } + + @Override + public boolean mightContainBinary(byte[] item) { + int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0); + int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1); + + long bitSize = bits.bitSize(); + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + if (!bits.get(combinedHash % bitSize)) { + return false; + } + } + return true; + } + + @Override + public boolean putLong(long item) { + // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce n + // hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions. + // Note that `CountMinSketch` use a different strategy, it hash the input long element with + // every i to produce n hash values. + // TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here? + int h1 = Murmur3_x86_32.hashLong(item, 0); + int h2 = Murmur3_x86_32.hashLong(item, h1); + + long bitSize = bits.bitSize(); + boolean bitsChanged = false; + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + bitsChanged |= bits.set(combinedHash % bitSize); + } + return bitsChanged; + } + + @Override + public boolean mightContainLong(long item) { + int h1 = Murmur3_x86_32.hashLong(item, 0); + int h2 = Murmur3_x86_32.hashLong(item, h1); + + long bitSize = bits.bitSize(); + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + if (!bits.get(combinedHash % bitSize)) { + return false; + } + } + return true; + } + + @Override + public boolean mightContain(Object item) { + if (item instanceof String) { + return mightContainString((String) item); + } else if (item instanceof byte[]) { + return mightContainBinary((byte[]) item); + } else { + return mightContainLong(Utils.integralToLong(item)); + } + } + + @Override + public boolean isCompatible(BloomFilter other) { + if (other == null) { + return false; + } + + if (!(other instanceof BloomFilterImpl)) { + return false; + } + + BloomFilterImpl that = (BloomFilterImpl) other; + return this.bitSize() == that.bitSize() && this.numHashFunctions == that.numHashFunctions; + } + + @Override + public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException { + BloomFilterImpl otherImplInstance = checkCompatibilityForMerge(other); + + this.bits.putAll(otherImplInstance.bits); + return this; + } + + @Override + public BloomFilter intersectInPlace(BloomFilter other) throws IncompatibleMergeException { + BloomFilterImpl otherImplInstance = checkCompatibilityForMerge(other); + + this.bits.and(otherImplInstance.bits); + return this; + } + + @Override + public long cardinality() { + return this.bits.cardinality(); + } + + private BloomFilterImpl checkCompatibilityForMerge(BloomFilter other) + throws IncompatibleMergeException { + // Duplicates the logic of `isCompatible` here to provide better error message. + if (other == null) { + throw new IncompatibleMergeException("Cannot merge null bloom filter"); + } + + if (!(other instanceof BloomFilterImpl)) { + throw new IncompatibleMergeException( + "Cannot merge bloom filter of class " + other.getClass().getName() + ); + } + + BloomFilterImpl that = (BloomFilterImpl) other; + + if (this.bitSize() != that.bitSize()) { + throw new IncompatibleMergeException("Cannot merge bloom filters with different bit size"); + } + + if (this.numHashFunctions != that.numHashFunctions) { + throw new IncompatibleMergeException( + "Cannot merge bloom filters with different number of hash functions" + ); + } + return that; + } + + @Override + public void writeTo(OutputStream out) throws IOException { + DataOutputStream dos = new DataOutputStream(out); + + dos.writeInt(Version.V1.getVersionNumber()); + dos.writeInt(numHashFunctions); + bits.writeTo(dos); + } + + private void readFrom0(InputStream in) throws IOException { + DataInputStream dis = new DataInputStream(in); + + int version = dis.readInt(); + if (version != Version.V1.getVersionNumber()) { + throw new IOException("Unexpected Bloom filter version number (" + version + ")"); + } + + this.numHashFunctions = dis.readInt(); + this.bits = BitArray.readFrom(dis); + } + + public static BloomFilterImpl readFrom(InputStream in) throws IOException { + BloomFilterImpl filter = new BloomFilterImpl(); + filter.readFrom0(in); + return filter; + } + + private void writeObject(ObjectOutputStream out) throws IOException { + writeTo(out); + } + + private void readObject(ObjectInputStream in) throws IOException { + readFrom0(in); + } +} diff --git a/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala new file mode 100644 index 000000000..9a1cf637e --- /dev/null +++ b/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.io.ByteArrayInputStream + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE +import org.apache.spark.sql.types._ +import org.apache.spark.util.sketch.BloomFilter + +/** + * An internal scalar function that returns the membership check result (either true or false) + * for values of `valueExpression` in the Bloom filter represented by `bloomFilterExpression`. + * Not that since the function is "might contain", always returning true regardless is not + * wrong. + * Note that this expression requires that `bloomFilterExpression` is either a constant value or + * an uncorrelated scalar subquery. This is sufficient for the Bloom filter join rewrite. + * + * @param bloomFilterExpression the Binary data of Bloom filter. + * @param valueExpression the Long value to be tested for the membership of `bloomFilterExpression`. + */ +case class BloomFilterMightContain( + bloomFilterExpression: Expression, + valueExpression: Expression) extends BinaryExpression { + + override def nullable: Boolean = true + override def left: Expression = bloomFilterExpression + override def right: Expression = valueExpression + override def prettyName: String = "might_contain" + override def dataType: DataType = BooleanType + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = (left.dataType, right.dataType) match { + case (BinaryType, NullType) | (NullType, LongType) | (NullType, NullType) | + (BinaryType, LongType) => TypeCheckResult.TypeCheckSuccess + case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${BinaryType.simpleString} followed by a value with ${LongType.simpleString}, " + + s"but it's [${left.dataType.catalogString}, ${right.dataType.catalogString}].") + } + if (typeCheckResult.isFailure) { + return typeCheckResult + } + bloomFilterExpression match { + case e : Expression if e.foldable => TypeCheckResult.TypeCheckSuccess + case subquery : PlanExpression[_] if !subquery.containsPattern(OUTER_REFERENCE) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure(s"The Bloom filter binary input to $prettyName " + + s"should be either a constant value or a scalar subquery expression") + } + } + + override protected def withNewChildrenInternal( + newBloomFilterExpression: Expression, + newValueExpression: Expression): BloomFilterMightContain = + copy(bloomFilterExpression = newBloomFilterExpression, + valueExpression = newValueExpression) + + // The bloom filter created from `bloomFilterExpression`. + @transient private var bloomFilter: BloomFilter = _ + + override def nullSafeEval(bloomFilterBytes: Any, value: Any): Any = { + if (bloomFilter == null) { + bloomFilter = deserialize(bloomFilterBytes.asInstanceOf[Array[Byte]]) + } + bloomFilter.mightContainLong(value.asInstanceOf[Long]) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val thisObj = ctx.addReferenceObj("thisObj", this) + nullSafeCodeGen(ctx, ev, (bloomFilterBytes, value) => { + s"\n${ev.value} = (Boolean) $thisObj.nullSafeEval($bloomFilterBytes, $value);\n" + }) + } + + final def deserialize(bytes: Array[Byte]): BloomFilter = { + val in = new ByteArrayInputStream(bytes) + val bloomFilter = BloomFilter.readFrom(in) + in.close() + bloomFilter + } + +} diff --git a/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala new file mode 100644 index 000000000..86d3d62e1 --- /dev/null +++ b/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.TernaryLike +import org.apache.spark.sql.types._ +import org.apache.spark.util.sketch.BloomFilter + +/** + * An internal aggregate function that creates a Bloom filter from input values. + * + * @param child Child expression of Long values for creating a Bloom filter. + * @param estimatedNumItemsExpression The number of estimated distinct items (optional). + * @param numBitsExpression The number of bits to use (optional). + */ +case class BloomFilterAggregate( + child: Expression, + estimatedNumItemsExpression: Expression, + numBitsExpression: Expression, + override val mutableAggBufferOffset: Int, + override val inputAggBufferOffset: Int) + extends TypedImperativeAggregate[BloomFilter] with TernaryLike[Expression] { + + def this(child: Expression, estimatedNumItemsExpression: Expression, + numBitsExpression: Expression) = { + this(child, estimatedNumItemsExpression, numBitsExpression, 0, 0) + } + + def this(child: Expression, estimatedNumItemsExpression: Expression) = { + this(child, estimatedNumItemsExpression, + // 1 byte per item. + Multiply(estimatedNumItemsExpression, Literal(8L))) + } + + def this(child: Expression) = { + this(child, Literal(BloomFilterAggregate.DEFAULT_EXPECTED_NUM_ITEMS), + Literal(BloomFilterAggregate.DEFAULT_NUM_BITS)) + } + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = (first.dataType, second.dataType, third.dataType) match { + case (_, NullType, _) | (_, _, NullType) => + TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as size arguments") + case (LongType, LongType, LongType) => TypeCheckResult.TypeCheckSuccess + case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been a ${LongType.simpleString} value followed with two ${LongType.simpleString} size " + + s"arguments, but it's [${first.dataType.catalogString}, " + + s"${second.dataType.catalogString}, ${third.dataType.catalogString}]") + } + if (typeCheckResult.isFailure) { + return typeCheckResult + } + if (!estimatedNumItemsExpression.foldable) { + TypeCheckFailure("The estimated number of items provided must be a constant literal") + } else if (estimatedNumItems <= 0L) { + TypeCheckFailure("The estimated number of items must be a positive value " + + s" (current value = $estimatedNumItems)") + } else if (!numBitsExpression.foldable) { + TypeCheckFailure("The number of bits provided must be a constant literal") + } else if (numBits <= 0L) { + TypeCheckFailure("The number of bits must be a positive value " + + s" (current value = $numBits)") + } else { + require(estimatedNumItems <= BloomFilterAggregate.MAX_ALLOWED_NUM_ITEMS) + require(numBits <= BloomFilterAggregate.MAX_NUM_BITS) + TypeCheckSuccess + } + } + override def nullable: Boolean = true + + override def dataType: DataType = BinaryType + + override def prettyName: String = "bloom_filter_agg" + + // Mark as lazy so that `estimatedNumItems` is not evaluated during tree transformation. + private lazy val estimatedNumItems: Long = + Math.min(estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue, + BloomFilterAggregate.MAX_ALLOWED_NUM_ITEMS) + + // Mark as lazy so that `numBits` is not evaluated during tree transformation. + private lazy val numBits: Long = + Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue, + BloomFilterAggregate.MAX_NUM_BITS) + + override def first: Expression = child + + override def second: Expression = estimatedNumItemsExpression + + override def third: Expression = numBitsExpression + + override protected def withNewChildrenInternal(newChild: Expression, + newEstimatedNumItemsExpression: Expression, newNumBitsExpression: Expression) + : BloomFilterAggregate = { + copy(child = newChild, estimatedNumItemsExpression = newEstimatedNumItemsExpression, + numBitsExpression = newNumBitsExpression) + } + + override def createAggregationBuffer(): BloomFilter = { + BloomFilter.create(estimatedNumItems, numBits) + } + + override def update(buffer: BloomFilter, inputRow: InternalRow): BloomFilter = { + val value = child.eval(inputRow) + // Ignore null values. + if (value == null) { + return buffer + } + buffer.putLong(value.asInstanceOf[Long]) + buffer + } + + override def merge(buffer: BloomFilter, other: BloomFilter): BloomFilter = { + buffer.mergeInPlace(other) + } + + override def eval(buffer: BloomFilter): Any = { + if (buffer.cardinality() == 0) { + // There's no set bit in the Bloom filter and hence no not-null value is processed. + return null + } + serialize(buffer) + } + + override def withNewMutableAggBufferOffset(newOffset: Int): BloomFilterAggregate = + copy(mutableAggBufferOffset = newOffset) + + override def withNewInputAggBufferOffset(newOffset: Int): BloomFilterAggregate = + copy(inputAggBufferOffset = newOffset) + + override def serialize(obj: BloomFilter): Array[Byte] = { + BloomFilterAggregate.serde.serialize(obj) + } + + override def deserialize(bytes: Array[Byte]): BloomFilter = { + BloomFilterAggregate.serde.deserialize(bytes) + } +} + +object BloomFilterAggregate { + + val DEFAULT_EXPECTED_NUM_ITEMS: Long = 1000000L // Default 1M distinct items + + val MAX_ALLOWED_NUM_ITEMS: Long = 4000000L // At most 4M distinct items + + val DEFAULT_NUM_BITS: Long = 8388608 // Default 1MB + + val MAX_NUM_BITS: Long = 67108864 // At most 8MB + + /** + * Serializer/Deserializer for class [[BloomFilter]] + * + * This class is thread safe. + */ + class BloomFilterSerDe { + + final def serialize(obj: BloomFilter): Array[Byte] = { + val size = obj.bitSize()/8 + require(size <= Integer.MAX_VALUE, s"actual number of bits is too large $size") + val out = new ByteArrayOutputStream(size.intValue()) + obj.writeTo(out) + out.close() + out.toByteArray + } + + final def deserialize(bytes: Array[Byte]): BloomFilter = { + val in = new ByteArrayInputStream(bytes) + val bloomFilter = BloomFilter.readFrom(in) + in.close() + bloomFilter + } + } + + val serde: BloomFilterSerDe = new BloomFilterSerDe +} diff --git a/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala new file mode 100644 index 000000000..c42843764 --- /dev/null +++ b/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -0,0 +1,1950 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.objects + +import java.lang.reflect.{Method, Modifier} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.{Builder, WrappedArray} +import scala.reflect.ClassTag +import scala.util.{Properties, Try} + +import org.apache.commons.lang3.reflect.MethodUtils + +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.serializer._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.TernaryLike +import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, _} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +/** + * Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]]. + */ +trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInputTypes { + + def arguments: Seq[Expression] + + def propagateNull: Boolean + + protected lazy val needNullCheck: Boolean = needNullCheckForIndex.contains(true) + protected lazy val needNullCheckForIndex: Array[Boolean] = + arguments.map(a => a.nullable && (propagateNull || + ScalaReflection.dataTypeJavaClass(a.dataType).isPrimitive)).toArray + protected lazy val evaluatedArgs: Array[Object] = new Array[Object](arguments.length) + private lazy val boxingFn: Any => Any = + ScalaReflection.typeBoxedJavaMapping + .get(dataType) + .map(cls => v => cls.cast(v)) + .getOrElse(identity) + + + /** + * Prepares codes for arguments. + * + * - generate codes for argument. + * - use ctx.splitExpressions() to not exceed 64kb JVM limit while preparing arguments. + * - avoid some of nullability checking which are not needed because the expression is not + * nullable. + * - when needNullCheck == true, short circuit if we found one of arguments is null because + * preparing rest of arguments can be skipped in the case. + * + * @param ctx a [[CodegenContext]] + * @return (code to prepare arguments, argument string, result of argument null check) + */ + def prepareArguments(ctx: CodegenContext): (String, String, ExprValue) = { + + val resultIsNull = if (needNullCheck) { + val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") + JavaCode.isNullGlobal(resultIsNull) + } else { + FalseLiteral + } + val argValues = arguments.map { e => + val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue") + argValue + } + + val argCodes = if (needNullCheck) { + val reset = s"$resultIsNull = false;" + val argCodes = arguments.zipWithIndex.map { case (e, i) => + val expr = e.genCode(ctx) + val updateResultIsNull = if (needNullCheckForIndex(i)) { + s"$resultIsNull = ${expr.isNull};" + } else { + "" + } + s""" + if (!$resultIsNull) { + ${expr.code} + $updateResultIsNull + ${argValues(i)} = ${expr.value}; + } + """ + } + reset +: argCodes + } else { + arguments.zipWithIndex.map { case (e, i) => + val expr = e.genCode(ctx) + s""" + ${expr.code} + ${argValues(i)} = ${expr.value}; + """ + } + } + val argCode = ctx.splitExpressionsWithCurrentInputs(argCodes) + + (argCode, argValues.mkString(", "), resultIsNull) + } + + /** + * Evaluate each argument with a given row, invoke a method with a given object and arguments, + * and cast a return value if the return type can be mapped to a Java Boxed type + * + * @param obj the object for the method to be called. If null, perform s static method call + * @param method the method object to be called + * @param input the row used for evaluating arguments + * @return the return object of a method call + */ + def invoke(obj: Any, method: Method, input: InternalRow): Any = { + var i = 0 + val len = arguments.length + var resultNull = false + while (i < len) { + val result = arguments(i).eval(input).asInstanceOf[Object] + evaluatedArgs(i) = result + resultNull = resultNull || (result == null && needNullCheckForIndex(i)) + i += 1 + } + if (needNullCheck && resultNull) { + // return null if one of arguments is null + null + } else { + val ret = try { + method.invoke(obj, evaluatedArgs: _*) + } catch { + // Re-throw the original exception. + case e: java.lang.reflect.InvocationTargetException if e.getCause != null => + throw e.getCause + } + boxingFn(ret) + } + } + + final def findMethod(cls: Class[_], functionName: String, argClasses: Seq[Class[_]]): Method = { + val method = MethodUtils.getMatchingAccessibleMethod(cls, functionName, argClasses: _*) + if (method == null) { + throw QueryExecutionErrors.methodNotDeclaredError(functionName) + } else { + method + } + } +} + +/** + * Common trait for [[DecodeUsingSerializer]] and [[EncodeUsingSerializer]] + */ +trait SerializerSupport { + /** + * If true, Kryo serialization is used, otherwise the Java one is used + */ + val kryo: Boolean + + /** + * The serializer instance to be used for serialization/deserialization in interpreted execution + */ + lazy val serializerInstance: SerializerInstance = SerializerSupport.newSerializer(kryo) + + /** + * Adds a immutable state to the generated class containing a reference to the serializer. + * @return a string containing the name of the variable referencing the serializer + */ + def addImmutableSerializerIfNeeded(ctx: CodegenContext): String = { + val (serializerInstance, serializerInstanceClass) = { + if (kryo) { + ("kryoSerializer", + classOf[KryoSerializerInstance].getName) + } else { + ("javaSerializer", + classOf[JavaSerializerInstance].getName) + } + } + val newSerializerMethod = s"${classOf[SerializerSupport].getName}$$.MODULE$$.newSerializer" + // Code to initialize the serializer + ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializerInstance, v => + s""" + |$v = ($serializerInstanceClass) $newSerializerMethod($kryo); + """.stripMargin) + serializerInstance + } +} + +object SerializerSupport { + /** + * It creates a new `SerializerInstance` which is either a `KryoSerializerInstance` (is + * `useKryo` is set to `true`) or a `JavaSerializerInstance`. + */ + def newSerializer(useKryo: Boolean): SerializerInstance = { + // try conf from env, otherwise create a new one + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + val s = if (useKryo) { + new KryoSerializer(conf) + } else { + new JavaSerializer(conf) + } + s.newInstance() + } +} + +/** + * Invokes a static function, returning the result. By default, any of the arguments being null + * will result in returning null instead of calling the function. + * + * @param staticObject The target of the static call. This can either be the object itself + * (methods defined on scala objects), or the class object + * (static methods defined in java). + * @param dataType The expected return type of the function call + * @param functionName The name of the method to call. + * @param arguments An optional list of expressions to pass as arguments to the function. + * @param inputTypes A list of data types specifying the input types for the method to be invoked. + * If enabled, it must have the same length as [[arguments]]. In case an input + * type differs from the actual argument type, Spark will try to perform + * type coercion and insert cast whenever necessary before invoking the method. + * The above is disabled if this is empty. + * @param propagateNull When true, and any of the arguments is null, null will be returned instead + * of calling the function. Also note: when this is false but any of the + * arguments is of primitive type and is null, null also will be returned + * without invoking the function. + * @param returnNullable When false, indicating the invoked method will always return + * non-null value. + * @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark + * will not apply certain optimizations such as constant folding. + */ +case class StaticInvoke( + staticObject: Class[_], + dataType: DataType, + functionName: String, + arguments: Seq[Expression] = Nil, + inputTypes: Seq[AbstractDataType] = Nil, + propagateNull: Boolean = true, + returnNullable: Boolean = true, + isDeterministic: Boolean = true) extends InvokeLike { + + val objectName = staticObject.getName.stripSuffix("$") + val cls = if (staticObject.getName == objectName) { + staticObject + } else { + Utils.classForName(objectName) + } + + override def nullable: Boolean = needNullCheck || returnNullable + override def children: Seq[Expression] = arguments + override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic) + + lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) + @transient lazy val method = findMethod(cls, functionName, argClasses) + + override def eval(input: InternalRow): Any = { + invoke(null, method, input) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val javaType = CodeGenerator.javaType(dataType) + + val (argCode, argString, resultIsNull) = prepareArguments(ctx) + + val callFunc = s"$objectName.$functionName($argString)" + + val prepareIsNull = if (nullable) { + s"boolean ${ev.isNull} = $resultIsNull;" + } else { + ev.isNull = FalseLiteral + "" + } + + val evaluate = if (returnNullable && !method.getReturnType.isPrimitive) { + if (CodeGenerator.defaultValue(dataType) == "null") { + s""" + ${ev.value} = $callFunc; + ${ev.isNull} = ${ev.value} == null; + """ + } else { + val boxedResult = ctx.freshName("boxedResult") + s""" + ${CodeGenerator.boxedType(dataType)} $boxedResult = $callFunc; + ${ev.isNull} = $boxedResult == null; + if (!${ev.isNull}) { + ${ev.value} = $boxedResult; + } + """ + } + } else { + s"${ev.value} = $callFunc;" + } + + val code = code""" + $argCode + $prepareIsNull + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + if (!$resultIsNull) { + $evaluate + } + """ + ev.copy(code = code) + } + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(arguments = newChildren) +} + +/** + * Calls the specified function on an object, optionally passing arguments. If the `targetObject` + * expression evaluates to null then null will be returned. + * + * In some cases, due to erasure, the schema may expect a primitive type when in fact the method + * is returning java.lang.Object. In this case, we will generate code that attempts to unbox the + * value automatically. + * + * @param targetObject An expression that will return the object to call the method on. + * @param functionName The name of the method to call. + * @param dataType The expected return type of the function. + * @param arguments An optional list of expressions, whose evaluation will be passed to the + * function. + * @param methodInputTypes A list of data types specifying the input types for the method to be + * invoked. If enabled, it must have the same length as [[arguments]]. In + * case an input type differs from the actual argument type, Spark will + * try to perform type coercion and insert cast whenever necessary before + * invoking the method. The type coercion is disabled if this is empty. + * @param propagateNull When true, and any of the arguments is null, null will be returned instead + * of calling the function. Also note: when this is false but any of the + * arguments is of primitive type and is null, null also will be returned + * without invoking the function. + * @param returnNullable When false, indicating the invoked method will always return + * non-null value. + * @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark + * will not apply certain optimizations such as constant folding. + */ +case class Invoke( + targetObject: Expression, + functionName: String, + dataType: DataType, + arguments: Seq[Expression] = Nil, + methodInputTypes: Seq[AbstractDataType] = Nil, + propagateNull: Boolean = true, + returnNullable : Boolean = true, + isDeterministic: Boolean = true) extends InvokeLike { + + lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) + + final override val nodePatterns: Seq[TreePattern] = Seq(INVOKE) + + override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable + override def children: Seq[Expression] = targetObject +: arguments + override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic) + override def inputTypes: Seq[AbstractDataType] = + if (methodInputTypes.nonEmpty) { + Seq(targetObject.dataType) ++ methodInputTypes + } else { + Nil + } + + private lazy val encodedFunctionName = ScalaReflection.encodeFieldNameToIdentifier(functionName) + + @transient lazy val method = targetObject.dataType match { + case ObjectType(cls) => + Some(findMethod(cls, encodedFunctionName, argClasses)) + case _ => None + } + + override def eval(input: InternalRow): Any = { + val obj = targetObject.eval(input) + if (obj == null) { + // return null if obj is null + null + } else { + val invokeMethod = if (method.isDefined) { + method.get + } else { + obj.getClass.getMethod(functionName, argClasses: _*) + } + invoke(obj, invokeMethod, input) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val javaType = CodeGenerator.javaType(dataType) + val obj = targetObject.genCode(ctx) + + val (argCode, argString, resultIsNull) = prepareArguments(ctx) + + val returnPrimitive = method.isDefined && method.get.getReturnType.isPrimitive + val needTryCatch = method.isDefined && method.get.getExceptionTypes.nonEmpty + + def getFuncResult(resultVal: String, funcCall: String): String = if (needTryCatch) { + s""" + try { + $resultVal = $funcCall; + } catch (Exception e) { + org.apache.spark.unsafe.Platform.throwException(e); + } + """ + } else { + s"$resultVal = $funcCall;" + } + + val evaluate = if (returnPrimitive) { + getFuncResult(ev.value, s"${obj.value}.$encodedFunctionName($argString)") + } else { + val funcResult = ctx.freshName("funcResult") + // If the function can return null, we do an extra check to make sure our null bit is still + // set correctly. + val assignResult = if (!returnNullable) { + s"${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;" + } else { + s""" + if ($funcResult != null) { + ${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult; + } else { + ${ev.isNull} = true; + } + """ + } + s""" + Object $funcResult = null; + ${getFuncResult(funcResult, s"${obj.value}.$encodedFunctionName($argString)")} + $assignResult + """ + } + + val mainEvalCode = + code""" + |$argCode + |${ev.isNull} = $resultIsNull; + |if (!${ev.isNull}) { + | $evaluate + |} + |""".stripMargin + + val evalWithNullCheck = if (targetObject.nullable) { + code""" + |if (!${obj.isNull}) { + | $mainEvalCode + |} + |""".stripMargin + } else { + mainEvalCode + } + + val code = obj.code + code""" + boolean ${ev.isNull} = true; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $evalWithNullCheck + """ + ev.copy(code = code) + } + + override def toString: String = s"$targetObject.$functionName" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Invoke = + copy(targetObject = newChildren.head, arguments = newChildren.tail) +} + +object NewInstance { + def apply( + cls: Class[_], + arguments: Seq[Expression], + dataType: DataType, + propagateNull: Boolean = true): NewInstance = + new NewInstance(cls, arguments, inputTypes = Nil, propagateNull, dataType, None) +} + +/** + * Constructs a new instance of the given class, using the result of evaluating the specified + * expressions as arguments. + * + * @param cls The class to construct. + * @param arguments A list of expression to use as arguments to the constructor. + * @param inputTypes A list of data types specifying the input types for the method to be invoked. + * If enabled, it must have the same length as [[arguments]]. In case an input + * type differs from the actual argument type, Spark will try to perform + * type coercion and insert cast whenever necessary before invoking the method. + * The above is disabled if this is empty. + * @param propagateNull When true, if any of the arguments is null, then null will be returned + * instead of trying to construct the object. Also note: when this is false + * but any of the arguments is of primitive type and is null, null also will + * be returned without constructing the object. + * @param dataType The type of object being constructed, as a Spark SQL datatype. This allows you + * to manually specify the type when the object in question is a valid internal + * representation (i.e. ArrayData) instead of an object. + * @param outerPointer If the object being constructed is an inner class, the outerPointer for the + * containing class must be specified. This parameter is defined as an optional + * function, which allows us to get the outer pointer lazily,and it's useful if + * the inner class is defined in REPL. + */ +case class NewInstance( + cls: Class[_], + arguments: Seq[Expression], + inputTypes: Seq[AbstractDataType], + propagateNull: Boolean, + dataType: DataType, + outerPointer: Option[() => AnyRef]) extends InvokeLike { + private val className = cls.getName + + override def nullable: Boolean = needNullCheck + + override def children: Seq[Expression] = arguments + + final override val nodePatterns: Seq[TreePattern] = Seq(NEW_INSTANCE) + + override lazy val resolved: Boolean = { + // If the class to construct is an inner class, we need to get its outer pointer, or this + // expression should be regarded as unresolved. + // Note that static inner classes (e.g., inner classes within Scala objects) don't need + // outer pointer registration. + val needOuterPointer = + outerPointer.isEmpty && Utils.isMemberClass(cls) && !Modifier.isStatic(cls.getModifiers) + childrenResolved && !needOuterPointer + } + + @transient private lazy val constructor: (Seq[AnyRef]) => Any = { + val paramTypes = ScalaReflection.expressionJavaClasses(arguments) + val getConstructor = (paramClazz: Seq[Class[_]]) => { + ScalaReflection.findConstructor(cls, paramClazz).getOrElse { + throw QueryExecutionErrors.constructorNotFoundError(cls.toString) + } + } + outerPointer.map { p => + val outerObj = p() + val c = getConstructor(outerObj.getClass +: paramTypes) + (args: Seq[AnyRef]) => { + c(outerObj +: args) + } + }.getOrElse { + val c = getConstructor(paramTypes) + (args: Seq[AnyRef]) => { + c(args) + } + } + } + + override def eval(input: InternalRow): Any = { + val argValues = arguments.map(_.eval(input)) + constructor(argValues.map(_.asInstanceOf[AnyRef])) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val javaType = CodeGenerator.javaType(dataType) + + val (argCode, argString, resultIsNull) = prepareArguments(ctx) + + val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx)) + + ev.isNull = resultIsNull + + val constructorCall = cls.getConstructors.size match { + // If there are no constructors, the `new` method will fail. In + // this case we can try to call the apply method constructor + // that might be defined on the companion object. + case 0 => s"$className$$.MODULE$$.apply($argString)" + case _ => outer.map { gen => + s"${gen.value}.new ${Utils.getSimpleName(cls)}($argString)" + }.getOrElse { + s"new $className($argString)" + } + } + + val code = code""" + $argCode + ${outer.map(_.code).getOrElse("")} + final $javaType ${ev.value} = ${ev.isNull} ? + ${CodeGenerator.defaultValue(dataType)} : $constructorCall; + """ + ev.copy(code = code) + } + + override def toString: String = s"newInstance($cls)" + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): NewInstance = + copy(arguments = newChildren) +} + +/** + * Given an expression that returns on object of type `Option[_]`, this expression unwraps the + * option into the specified Spark SQL datatype. In the case of `None`, the nullbit is set instead. + * + * @param dataType The expected unwrapped option type. + * @param child An expression that returns an `Option` + */ +case class UnwrapOption( + dataType: DataType, + child: Expression) extends UnaryExpression with NonSQLExpression with ExpectsInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil + + override def eval(input: InternalRow): Any = { + val inputObject = child.eval(input) + if (inputObject == null) { + null + } else { + inputObject.asInstanceOf[Option[_]].orNull + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val javaType = CodeGenerator.javaType(dataType) + val inputObject = child.genCode(ctx) + + val code = inputObject.code + code""" + final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); + $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : + (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); + """ + ev.copy(code = code) + } + + override protected def withNewChildInternal(newChild: Expression): UnwrapOption = + copy(child = newChild) +} + +/** + * Converts the result of evaluating `child` into an option, checking both the isNull bit and + * (in the case of reference types) equality with null. + * + * @param child The expression to evaluate and wrap. + * @param optType The type of this option. + */ +case class WrapOption(child: Expression, optType: DataType) + extends UnaryExpression with NonSQLExpression with ExpectsInputTypes { + + override def dataType: DataType = ObjectType(classOf[Option[_]]) + + override def nullable: Boolean = false + + override def inputTypes: Seq[AbstractDataType] = optType :: Nil + + override def eval(input: InternalRow): Any = Option(child.eval(input)) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val inputObject = child.genCode(ctx) + + val code = inputObject.code + code""" + scala.Option ${ev.value} = + ${inputObject.isNull} ? + scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); + """ + ev.copy(code = code, isNull = FalseLiteral) + } + + override protected def withNewChildInternal(newChild: Expression): WrapOption = + copy(child = newChild) +} + +object LambdaVariable { + private val curId = new java.util.concurrent.atomic.AtomicLong() + + // Returns the codegen-ed `LambdaVariable` and add it to mutable states, so that it can be + // accessed anywhere in the generated code. + def prepareLambdaVariable(ctx: CodegenContext, variable: LambdaVariable): ExprCode = { + val variableCode = variable.genCode(ctx) + assert(variableCode.code.isEmpty) + + ctx.addMutableState( + CodeGenerator.javaType(variable.dataType), + variableCode.value, + forceInline = true, + useFreshName = false) + + if (variable.nullable) { + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, + variableCode.isNull, + forceInline = true, + useFreshName = false) + } + + variableCode + } +} + +/** + * A placeholder for the loop variable used in [[MapObjects]]. This should never be constructed + * manually, but will instead be passed into the provided lambda function. + */ +// TODO: Merge this and `NamedLambdaVariable`. +case class LambdaVariable( + name: String, + dataType: DataType, + nullable: Boolean, + id: Long = LambdaVariable.curId.incrementAndGet) extends LeafExpression with NonSQLExpression { + + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType, nullable) + + final override val nodePatterns: Seq[TreePattern] = Seq(LAMBDA_VARIABLE) + + // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. + override def eval(input: InternalRow): Any = { + assert(input.numFields == 1, + "The input row of interpreted LambdaVariable should have only 1 field.") + accessor(input, 0) + } + + override def genCode(ctx: CodegenContext): ExprCode = { + // If `LambdaVariable` IDs are reassigned by the `ReassignLambdaVariableID` rule, the IDs will + // all be negative. + val suffix = "lambda_variable_" + math.abs(id) + val isNull = if (nullable) { + JavaCode.isNullVariable(s"isNull_${name}_$suffix") + } else { + FalseLiteral + } + val value = JavaCode.variable(s"value_${name}_$suffix", dataType) + ExprCode(isNull, value) + } + + // This won't be called as `genCode` is overrided, just overriding it to make + // `LambdaVariable` non-abstract. + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev +} + +/** + * When constructing [[MapObjects]], the element type must be given, which may not be available + * before analysis. This class acts like a placeholder for [[MapObjects]], and will be replaced by + * [[MapObjects]] during analysis after the input data is resolved. + * Note that, ideally we should not serialize and send unresolved expressions to executors, but + * users may accidentally do this(e.g. mistakenly reference an encoder instance when implementing + * Aggregator). Here we mark `function` as transient because it may reference scala Type, which is + * not serializable. Then even users mistakenly reference unresolved expression and serialize it, + * it's just a performance issue(more network traffic), and will not fail. + */ +case class UnresolvedMapObjects( + @transient function: Expression => Expression, + child: Expression, + customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable { + override lazy val resolved = false + + override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse { + throw QueryExecutionErrors.customCollectionClsNotResolvedError + } + + override protected def withNewChildInternal(newChild: Expression): UnresolvedMapObjects = + copy(child = newChild) +} + +object MapObjects { + /** + * Construct an instance of MapObjects case class. + * + * @param function The function applied on the collection elements. + * @param inputData An expression that when evaluated returns a collection object. + * @param elementType The data type of elements in the collection. + * @param elementNullable When false, indicating elements in the collection are always + * non-null value. + * @param customCollectionCls Class of the resulting collection (returning ObjectType) + * or None (returning ArrayType) + */ + def apply( + function: Expression => Expression, + inputData: Expression, + elementType: DataType, + elementNullable: Boolean = true, + customCollectionCls: Option[Class[_]] = None): MapObjects = { + // UnresolvedMapObjects does not serialize its 'function' field. + // If an array expression or array Encoder is not correctly resolved before + // serialization, this exception condition may occur. + require(function != null, + "MapObjects applied with a null function. " + + "Likely cause is failure to resolve an array expression or encoder. " + + "(See UnresolvedMapObjects)") + val loopVar = LambdaVariable("MapObject", elementType, elementNullable) + MapObjects(loopVar, function(loopVar), inputData, customCollectionCls) + } +} + +/** + * Applies the given expression to every element of a collection of items, returning the result + * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda + * function is expressed using catalyst expressions. + * + * The type of the result is determined as follows: + * - ArrayType - when customCollectionCls is None + * - ObjectType(collection) - when customCollectionCls contains a collection class + * + * The following collection ObjectTypes are currently supported on input: + * Seq, Array, ArrayData, java.util.List + * + * @param loopVar the [[LambdaVariable]] expression representing the loop variable that used to + * iterate the collection, and used as input for the `lambdaFunction`. + * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function + * to handle collection elements. + * @param inputData An expression that when evaluated returns a collection object. + * @param customCollectionCls Class of the resulting collection (returning ObjectType) + * or None (returning ArrayType) + */ +case class MapObjects private( + loopVar: LambdaVariable, + lambdaFunction: Expression, + inputData: Expression, + customCollectionCls: Option[Class[_]]) extends Expression with NonSQLExpression + with TernaryLike[Expression] { + + override def nullable: Boolean = inputData.nullable + + override def first: Expression = loopVar + override def second: Expression = lambdaFunction + override def third: Expression = inputData + + final override val nodePatterns: Seq[TreePattern] = Seq(MAP_OBJECTS) + + // The data with UserDefinedType are actually stored with the data type of its sqlType. + // When we want to apply MapObjects on it, we have to use it. + lazy private val inputDataType = inputData.dataType match { + case u: UserDefinedType[_] => u.sqlType + case _ => inputData.dataType + } + + private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = { + val row = new GenericInternalRow(1) + inputCollection.toIterator.map { element => + row.update(0, element) + lambdaFunction.eval(row) + } + } + + private lazy val convertToSeq: Any => Seq[_] = inputDataType match { + case ObjectType(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) => + _.asInstanceOf[scala.collection.Seq[_]].toSeq + case ObjectType(cls) if cls.isArray => + _.asInstanceOf[Array[_]].toSeq + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + _.asInstanceOf[java.util.List[_]].asScala.toSeq + case ObjectType(cls) if cls == classOf[Object] => + (inputCollection) => { + if (inputCollection.getClass.isArray) { + inputCollection.asInstanceOf[Array[_]].toSeq + } else { + inputCollection.asInstanceOf[Seq[_]] + } + } + case ArrayType(et, _) => + _.asInstanceOf[ArrayData].toSeq[Any](et) + } + + private lazy val mapElements: Seq[_] => Any = customCollectionCls match { + case Some(cls) if classOf[WrappedArray[_]].isAssignableFrom(cls) => + // Scala WrappedArray + inputCollection => WrappedArray.make(executeFuncOnCollection(inputCollection).toArray) + case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) => + // Scala sequence + executeFuncOnCollection(_).toSeq + case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) => + // Scala set + executeFuncOnCollection(_).toSet + case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + // Java list + if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] || + cls == classOf[java.util.AbstractSequentialList[_]]) { + // Specifying non concrete implementations of `java.util.List` + executeFuncOnCollection(_).toSeq.asJava + } else { + val constructors = cls.getConstructors() + val intParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int] + } + val noParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 0 + } + + val constructor = intParamConstructor.map { intConstructor => + (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object]) + }.getOrElse { + (_: Int) => noParamConstructor.get.newInstance() + } + + // Specifying concrete implementations of `java.util.List` + (inputs) => { + val results = executeFuncOnCollection(inputs) + val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]] + results.foreach(builder.add(_)) + builder + } + } + case None => + // array + x => new GenericArrayData(executeFuncOnCollection(x).toArray) + case Some(cls) => + throw QueryExecutionErrors.classUnsupportedByMapObjectsError(cls) + } + + override def eval(input: InternalRow): Any = { + val inputCollection = inputData.eval(input) + + if (inputCollection == null) { + return null + } + mapElements(convertToSeq(inputCollection)) + } + + override def dataType: DataType = + customCollectionCls.map(ObjectType.apply).getOrElse( + ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val elementJavaType = CodeGenerator.javaType(loopVar.dataType) + val loopVarCode = LambdaVariable.prepareLambdaVariable(ctx, loopVar) + val genInputData = inputData.genCode(ctx) + val genFunction = lambdaFunction.genCode(ctx) + val dataLength = ctx.freshName("dataLength") + val convertedArray = ctx.freshName("convertedArray") + val loopIndex = ctx.freshName("loopIndex") + + val convertedType = CodeGenerator.boxedType(lambdaFunction.dataType) + + // Because of the way Java defines nested arrays, we have to handle the syntax specially. + // Specifically, we have to insert the [$dataLength] in between the type and any extra nested + // array declarations (i.e. new String[1][]). + val arrayConstructor = if (convertedType contains "[]") { + val rawType = convertedType.takeWhile(_ != '[') + val arrayPart = convertedType.reverse.takeWhile(c => c == '[' || c == ']').reverse + s"new $rawType[$dataLength]$arrayPart" + } else { + s"new $convertedType[$dataLength]" + } + + // In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type + // of input collection at runtime for this case. + val seq = ctx.freshName("seq") + val array = ctx.freshName("array") + val determineCollectionType = inputData.dataType match { + case ObjectType(cls) if cls == classOf[Object] => + val seqClass = classOf[scala.collection.Seq[_]].getName + s""" + $seqClass $seq = null; + $elementJavaType[] $array = null; + if (${genInputData.value}.getClass().isArray()) { + $array = ($elementJavaType[]) ${genInputData.value}; + } else { + $seq = ($seqClass) ${genInputData.value}; + } + """ + case _ => "" + } + + // `MapObjects` generates a while loop to traverse the elements of the input collection. We + // need to take care of Seq and List because they may have O(n) complexity for indexed accessing + // like `list.get(1)`. Here we use Iterator to traverse Seq and List. + val (getLength, prepareLoop, getLoopVar) = inputDataType match { + case ObjectType(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) => + val it = ctx.freshName("it") + ( + s"${genInputData.value}.size()", + s"scala.collection.Iterator $it = ${genInputData.value}.toIterator();", + s"$it.next()" + ) + case ObjectType(cls) if cls.isArray => + ( + s"${genInputData.value}.length", + "", + s"${genInputData.value}[$loopIndex]" + ) + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + val it = ctx.freshName("it") + ( + s"${genInputData.value}.size()", + s"java.util.Iterator $it = ${genInputData.value}.iterator();", + s"$it.next()" + ) + case ArrayType(et, _) => + ( + s"${genInputData.value}.numElements()", + "", + CodeGenerator.getValue(genInputData.value, et, loopIndex) + ) + case ObjectType(cls) if cls == classOf[Object] => + val it = ctx.freshName("it") + ( + s"$seq == null ? $array.length : $seq.size()", + s"scala.collection.Iterator $it = $seq == null ? null : $seq.toIterator();", + s"$it == null ? $array[$loopIndex] : $it.next()" + ) + } + + // Make a copy of the data if it's unsafe-backed + def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = + s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value" + val genFunctionValue: String = lambdaFunction.dataType match { + case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) + case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) + case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) + case _ => genFunction.value + } + + val loopNullCheck = if (loopVar.nullable) { + inputDataType match { + case _: ArrayType => s"${loopVarCode.isNull} = ${genInputData.value}.isNullAt($loopIndex);" + case _ => s"${loopVarCode.isNull} = ${loopVarCode.value} == null;" + } + } else { + "" + } + + val (initCollection, addElement, getResult): (String, String => String, String) = + customCollectionCls match { + case Some(cls) if classOf[WrappedArray[_]].isAssignableFrom(cls) => + def doCodeGenForScala212 = { + // WrappedArray in Scala 2.12 + val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()" + val builder = ctx.freshName("collectionBuilder") + ( + s""" + ${classOf[Builder[_, _]].getName} $builder = $getBuilder; + $builder.sizeHint($dataLength); + """, + (genValue: String) => s"$builder.$$plus$$eq($genValue);", + s"(${cls.getName}) ${classOf[WrappedArray[_]].getName}$$." + + s"MODULE$$.make(((${classOf[IndexedSeq[_]].getName})$builder" + + s".result()).toArray(scala.reflect.ClassTag$$.MODULE$$.Object()));" + ) + } + + def doCodeGenForScala213 = { + // In Scala 2.13, WrappedArray is mutable.ArraySeq and newBuilder method need + // a ClassTag type construction parameter + val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder(" + + s"scala.reflect.ClassTag$$.MODULE$$.Object())" + val builder = ctx.freshName("collectionBuilder") + ( + s""" + ${classOf[Builder[_, _]].getName} $builder = $getBuilder; + $builder.sizeHint($dataLength); + """, + (genValue: String) => s"$builder.$$plus$$eq($genValue);", + s"(${cls.getName})$builder.result();" + ) + } + + val scalaVersion = Properties.versionNumberString + if (scalaVersion.startsWith("2.12")) { + doCodeGenForScala212 + } else { + doCodeGenForScala213 + } + case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) || + classOf[scala.collection.Set[_]].isAssignableFrom(cls) => + // Scala sequence or set + val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()" + val builder = ctx.freshName("collectionBuilder") + ( + s""" + ${classOf[Builder[_, _]].getName} $builder = $getBuilder; + $builder.sizeHint($dataLength); + """, + (genValue: String) => s"$builder.$$plus$$eq($genValue);", + s"(${cls.getName}) $builder.result();" + ) + case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + // Java list + val builder = ctx.freshName("collectionBuilder") + ( + if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] || + cls == classOf[java.util.AbstractSequentialList[_]]) { + s"${cls.getName} $builder = new java.util.ArrayList($dataLength);" + } else { + val param = Try(cls.getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("") + s"${cls.getName} $builder = new ${cls.getName}($param);" + }, + (genValue: String) => s"$builder.add($genValue);", + s"$builder;" + ) + case _ => + // array + ( + s""" + $convertedType[] $convertedArray = null; + $convertedArray = $arrayConstructor; + """, + (genValue: String) => s"$convertedArray[$loopIndex] = $genValue;", + s"new ${classOf[GenericArrayData].getName}($convertedArray);" + ) + } + + val code = genInputData.code + code""" + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + + if (!${genInputData.isNull}) { + $determineCollectionType + int $dataLength = $getLength; + $initCollection + + int $loopIndex = 0; + $prepareLoop + while ($loopIndex < $dataLength) { + ${loopVarCode.value} = ($elementJavaType) ($getLoopVar); + $loopNullCheck + + ${genFunction.code} + if (${genFunction.isNull}) { + ${addElement("null")} + } else { + ${addElement(genFunctionValue)} + } + + $loopIndex += 1; + } + + ${ev.value} = $getResult + } + """ + ev.copy(code = code, isNull = genInputData.isNull) + } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + copy( + loopVar = newFirst.asInstanceOf[LambdaVariable], + lambdaFunction = newSecond, + inputData = newThird) +} + +/** + * Similar to [[UnresolvedMapObjects]], this is a placeholder of [[CatalystToExternalMap]]. + * + * @param child An expression that when evaluated returns a map object. + * @param keyFunction The function applied on the key collection elements. + * @param valueFunction The function applied on the value collection elements. + * @param collClass The type of the resulting collection. + */ +case class UnresolvedCatalystToExternalMap( + child: Expression, + @transient keyFunction: Expression => Expression, + @transient valueFunction: Expression => Expression, + collClass: Class[_]) extends UnaryExpression with Unevaluable { + + override lazy val resolved = false + + override def dataType: DataType = ObjectType(collClass) + + override protected def withNewChildInternal( + newChild: Expression): UnresolvedCatalystToExternalMap = copy(child = newChild) +} + +object CatalystToExternalMap { + def apply(u: UnresolvedCatalystToExternalMap): CatalystToExternalMap = { + val mapType = u.child.dataType.asInstanceOf[MapType] + val keyLoopVar = LambdaVariable( + "CatalystToExternalMap_key", mapType.keyType, nullable = false) + val valueLoopVar = LambdaVariable( + "CatalystToExternalMap_value", mapType.valueType, mapType.valueContainsNull) + CatalystToExternalMap( + keyLoopVar, u.keyFunction(keyLoopVar), + valueLoopVar, u.valueFunction(valueLoopVar), + u.child, u.collClass) + } +} + +/** + * Expression used to convert a Catalyst Map to an external Scala Map. + * The collection is constructed using the associated builder, obtained by calling `newBuilder` + * on the collection's companion object. + * + * @param keyLoopVar the [[LambdaVariable]] expression representing the loop variable that is used + * when iterating over the key collection, and which is used as input for the + * `keyLambdaFunction`. + * @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as + * a lambda function to handle collection elements. + * @param valueLoopVar the [[LambdaVariable]] expression representing the loop variable that is used + * when iterating over the value collection, and which is used as input for the + * `valueLambdaFunction`. + * @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as + * a lambda function to handle collection elements. + * @param inputData An expression that when evaluated returns a map object. + * @param collClass The type of the resulting collection. + */ +case class CatalystToExternalMap private( + keyLoopVar: LambdaVariable, + keyLambdaFunction: Expression, + valueLoopVar: LambdaVariable, + valueLambdaFunction: Expression, + inputData: Expression, + collClass: Class[_]) extends Expression with NonSQLExpression { + + override def nullable: Boolean = inputData.nullable + + override def children: Seq[Expression] = Seq( + keyLoopVar, keyLambdaFunction, valueLoopVar, valueLambdaFunction, inputData) + + private lazy val inputMapType = inputData.dataType.asInstanceOf[MapType] + + private lazy val (newMapBuilderMethod, moduleField) = { + val clazz = Utils.classForName(collClass.getCanonicalName + "$") + (clazz.getMethod("newBuilder"), clazz.getField("MODULE$").get(null)) + } + + private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + newMapBuilderMethod.invoke(moduleField).asInstanceOf[Builder[AnyRef, AnyRef]] + } + + override def eval(input: InternalRow): Any = { + val result = inputData.eval(input).asInstanceOf[MapData] + if (result != null) { + val builder = newMapBuilder() + builder.sizeHint(result.numElements()) + val keyArray = result.keyArray() + val valueArray = result.valueArray() + val row = new GenericInternalRow(1) + var i = 0 + while (i < result.numElements()) { + row.update(0, keyArray.get(i, inputMapType.keyType)) + val key = keyLambdaFunction.eval(row) + row.update(0, valueArray.get(i, inputMapType.valueType)) + val value = valueLambdaFunction.eval(row) + builder += Tuple2(key, value) + i += 1 + } + builder.result() + } else { + null + } + } + + override def dataType: DataType = ObjectType(collClass) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val keyCode = LambdaVariable.prepareLambdaVariable(ctx, keyLoopVar) + val valueCode = LambdaVariable.prepareLambdaVariable(ctx, valueLoopVar) + val keyElementJavaType = CodeGenerator.javaType(keyLoopVar.dataType) + val genKeyFunction = keyLambdaFunction.genCode(ctx) + val valueElementJavaType = CodeGenerator.javaType(valueLoopVar.dataType) + val genValueFunction = valueLambdaFunction.genCode(ctx) + val genInputData = inputData.genCode(ctx) + val dataLength = ctx.freshName("dataLength") + val loopIndex = ctx.freshName("loopIndex") + val tupleLoopValue = ctx.freshName("tupleLoopValue") + val builderValue = ctx.freshName("builderValue") + + val keyArray = ctx.freshName("keyArray") + val valueArray = ctx.freshName("valueArray") + val getKeyLoopVar = CodeGenerator.getValue(keyArray, keyLoopVar.dataType, loopIndex) + val getValueLoopVar = CodeGenerator.getValue(valueArray, valueLoopVar.dataType, loopIndex) + + // Make a copy of the data if it's unsafe-backed + def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = + s"$value instanceof ${clazz.getSimpleName}? $value.copy() : $value" + def genFunctionValue(lambdaFunction: Expression, genFunction: ExprCode) = + lambdaFunction.dataType match { + case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) + case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) + case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) + case _ => genFunction.value + } + val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction) + val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) + + val valueLoopNullCheck = if (valueLoopVar.nullable) { + s"${valueCode.isNull} = $valueArray.isNullAt($loopIndex);" + } else { + "" + } + + val builderClass = classOf[Builder[_, _]].getName + val constructBuilder = s""" + $builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder(); + $builderValue.sizeHint($dataLength); + """ + + val tupleClass = classOf[(_, _)].getName + val appendToBuilder = s""" + $tupleClass $tupleLoopValue; + + if (${genValueFunction.isNull}) { + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null); + } else { + $tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue); + } + + $builderValue.$$plus$$eq($tupleLoopValue); + """ + val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();" + + val code = genInputData.code + code""" + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + + if (!${genInputData.isNull}) { + int $dataLength = ${genInputData.value}.numElements(); + $constructBuilder + ArrayData $keyArray = ${genInputData.value}.keyArray(); + ArrayData $valueArray = ${genInputData.value}.valueArray(); + + int $loopIndex = 0; + while ($loopIndex < $dataLength) { + ${keyCode.value} = ($keyElementJavaType) ($getKeyLoopVar); + ${valueCode.value} = ($valueElementJavaType) ($getValueLoopVar); + $valueLoopNullCheck + + ${genKeyFunction.code} + ${genValueFunction.code} + + $appendToBuilder + + $loopIndex += 1; + } + + $getBuilderResult + } + """ + ev.copy(code = code, isNull = genInputData.isNull) + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): CatalystToExternalMap = + copy( + keyLoopVar = newChildren(0).asInstanceOf[LambdaVariable], + keyLambdaFunction = newChildren(1), + valueLoopVar = newChildren(2).asInstanceOf[LambdaVariable], + valueLambdaFunction = newChildren(3), + inputData = newChildren(4)) +} + +object ExternalMapToCatalyst { + def apply( + inputMap: Expression, + keyType: DataType, + keyConverter: Expression => Expression, + keyNullable: Boolean, + valueType: DataType, + valueConverter: Expression => Expression, + valueNullable: Boolean): ExternalMapToCatalyst = { + val keyLoopVar = LambdaVariable("ExternalMapToCatalyst_key", keyType, keyNullable) + val valueLoopVar = LambdaVariable("ExternalMapToCatalyst_value", valueType, valueNullable) + ExternalMapToCatalyst( + keyLoopVar, + keyConverter(keyLoopVar), + valueLoopVar, + valueConverter(valueLoopVar), + inputMap) + } +} + +/** + * Converts a Scala/Java map object into catalyst format, by applying the key/value converter when + * iterate the map. + * + * @param keyLoopVar the [[LambdaVariable]] expression representing the loop variable that is used + * when iterating over the key collection, and which is used as input for the + * `keyConverter`. + * @param keyConverter A function that take the `key` as input, and converts it to catalyst format. + * @param valueLoopVar the [[LambdaVariable]] expression representing the loop variable that is used + * when iterating over the value collection, and which is used as input for the + * `valueConverter`. + * @param valueConverter A function that take the `value` as input, and converts it to catalyst + * format. + * @param inputData An expression that when evaluated returns the input map object. + */ +case class ExternalMapToCatalyst private( + keyLoopVar: LambdaVariable, + keyConverter: Expression, + valueLoopVar: LambdaVariable, + valueConverter: Expression, + inputData: Expression) + extends Expression with NonSQLExpression { + + override def foldable: Boolean = false + + override def nullable: Boolean = inputData.nullable + + override def children: Seq[Expression] = Seq( + keyLoopVar, keyConverter, valueLoopVar, valueConverter, inputData) + + override def dataType: MapType = MapType( + keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) + + private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = { + val rowBuffer = InternalRow.fromSeq(Array[Any](1)) + def rowWrapper(data: Any): InternalRow = { + rowBuffer.update(0, data) + rowBuffer + } + + inputData.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[java.util.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + val iter = data.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + val (key, value) = (entry.getKey, entry.getValue) + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw QueryExecutionErrors.nullAsMapKeyNotAllowedError + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 + } + (keys, values) + } + + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[scala.collection.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + var i = 0 + for ((key, value) <- data) { + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw QueryExecutionErrors.nullAsMapKeyNotAllowedError + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 + } + (keys, values) + } + } + } + + override def eval(input: InternalRow): Any = { + val result = inputData.eval(input) + if (result != null) { + val (keys, values) = mapCatalystConverter(result) + new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) + } else { + null + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val inputMap = inputData.genCode(ctx) + val genKeyConverter = keyConverter.genCode(ctx) + val genValueConverter = valueConverter.genCode(ctx) + val length = ctx.freshName("length") + val index = ctx.freshName("index") + val convertedKeys = ctx.freshName("convertedKeys") + val convertedValues = ctx.freshName("convertedValues") + val entry = ctx.freshName("entry") + val entries = ctx.freshName("entries") + + val keyJavaType = CodeGenerator.javaType(keyLoopVar.dataType) + val valueJavaType = CodeGenerator.javaType(valueLoopVar.dataType) + val keyCode = LambdaVariable.prepareLambdaVariable(ctx, keyLoopVar) + val valueCode = LambdaVariable.prepareLambdaVariable(ctx, valueLoopVar) + + val (defineEntries, defineKeyValue) = inputData.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + val javaIteratorCls = classOf[java.util.Iterator[_]].getName + val javaMapEntryCls = classOf[java.util.Map.Entry[_, _]].getName + + val defineEntries = + s"final $javaIteratorCls $entries = ${inputMap.value}.entrySet().iterator();" + + val defineKeyValue = + s""" + final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); + ${keyCode.value} = (${CodeGenerator.boxedType(keyJavaType)}) $entry.getKey(); + ${valueCode.value} = (${CodeGenerator.boxedType(valueJavaType)}) $entry.getValue(); + """ + + defineEntries -> defineKeyValue + + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + val scalaIteratorCls = classOf[Iterator[_]].getName + val scalaMapEntryCls = classOf[Tuple2[_, _]].getName + + val defineEntries = s"final $scalaIteratorCls $entries = ${inputMap.value}.iterator();" + + val defineKeyValue = + s""" + final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); + ${keyCode.value} = (${CodeGenerator.boxedType(keyJavaType)}) $entry._1(); + ${valueCode.value} = (${CodeGenerator.boxedType(valueJavaType)}) $entry._2(); + """ + + defineEntries -> defineKeyValue + } + + val keyNullCheck = if (keyLoopVar.nullable) { + s"${keyCode.isNull} = ${keyCode.value} == null;" + } else { + "" + } + + val valueNullCheck = if (valueLoopVar.nullable) { + s"${valueCode.isNull} = ${valueCode.value} == null;" + } else { + "" + } + + val arrayCls = classOf[GenericArrayData].getName + val mapCls = classOf[ArrayBasedMapData].getName + val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) + val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) + val code = inputMap.code + + code""" + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + if (!${inputMap.isNull}) { + final int $length = ${inputMap.value}.size(); + final Object[] $convertedKeys = new Object[$length]; + final Object[] $convertedValues = new Object[$length]; + int $index = 0; + $defineEntries + while($entries.hasNext()) { + $defineKeyValue + $keyNullCheck + $valueNullCheck + + ${genKeyConverter.code} + if (${genKeyConverter.isNull}) { + throw QueryExecutionErrors.nullAsMapKeyNotAllowedError(); + } else { + $convertedKeys[$index] = ($convertedKeyType) ${genKeyConverter.value}; + } + + ${genValueConverter.code} + if (${genValueConverter.isNull}) { + $convertedValues[$index] = null; + } else { + $convertedValues[$index] = ($convertedValueType) ${genValueConverter.value}; + } + + $index++; + } + + ${ev.value} = new $mapCls(new $arrayCls($convertedKeys), new $arrayCls($convertedValues)); + } + """ + ev.copy(code = code, isNull = inputMap.isNull) + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): ExternalMapToCatalyst = + copy( + keyLoopVar = newChildren(0).asInstanceOf[LambdaVariable], + keyConverter = newChildren(1), + valueLoopVar = newChildren(2).asInstanceOf[LambdaVariable], + valueConverter = newChildren(3), + inputData = newChildren(4)) +} + +/** + * Constructs a new external row, using the result of evaluating the specified expressions + * as content. + * + * @param children A list of expression to use as content of the external row. + */ +case class CreateExternalRow(children: Seq[Expression], schema: StructType) + extends Expression with NonSQLExpression { + + override def dataType: DataType = ObjectType(classOf[Row]) + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + val values = children.map(_.eval(input)).toArray + new GenericRowWithSchema(values, schema) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val rowClass = classOf[GenericRowWithSchema].getName + val values = ctx.freshName("values") + + val childrenCodes = children.zipWithIndex.map { case (e, i) => + val eval = e.genCode(ctx) + s""" + |${eval.code} + |if (${eval.isNull}) { + | $values[$i] = null; + |} else { + | $values[$i] = ${eval.value}; + |} + """.stripMargin + } + + val childrenCode = ctx.splitExpressionsWithCurrentInputs( + expressions = childrenCodes, + funcName = "createExternalRow", + extraArguments = "Object[]" -> values :: Nil) + val schemaField = ctx.addReferenceObj("schema", schema) + + val code = + code""" + |Object[] $values = new Object[${children.size}]; + |$childrenCode + |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); + """.stripMargin + ev.copy(code = code, isNull = FalseLiteral) + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): CreateExternalRow = copy(children = newChildren) +} + +/** + * Serializes an input object using a generic serializer (Kryo or Java). + * + * @param kryo if true, use Kryo. Otherwise, use Java. + */ +case class EncodeUsingSerializer(child: Expression, kryo: Boolean) + extends UnaryExpression with NonSQLExpression with SerializerSupport { + + override def nullSafeEval(input: Any): Any = { + serializerInstance.serialize(input).array() + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val serializer = addImmutableSerializerIfNeeded(ctx) + // Code to serialize. + val input = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val serialize = s"$serializer.serialize(${input.value}, null).array()" + + val code = input.code + code""" + final $javaType ${ev.value} = + ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize; + """ + ev.copy(code = code, isNull = input.isNull) + } + + override def dataType: DataType = BinaryType + + override protected def withNewChildInternal(newChild: Expression): EncodeUsingSerializer = + copy(child = newChild) +} + +/** + * Serializes an input object using a generic serializer (Kryo or Java). Note that the ClassTag + * is not an implicit parameter because TreeNode cannot copy implicit parameters. + * + * @param kryo if true, use Kryo. Otherwise, use Java. + */ +case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) + extends UnaryExpression with NonSQLExpression with SerializerSupport { + + override def nullSafeEval(input: Any): Any = { + val inputBytes = java.nio.ByteBuffer.wrap(input.asInstanceOf[Array[Byte]]) + serializerInstance.deserialize(inputBytes) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val serializer = addImmutableSerializerIfNeeded(ctx) + // Code to deserialize. + val input = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val deserialize = + s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" + + val code = input.code + code""" + final $javaType ${ev.value} = + ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; + """ + ev.copy(code = code, isNull = input.isNull) + } + + override def dataType: DataType = ObjectType(tag.runtimeClass) + + override protected def withNewChildInternal(newChild: Expression): DecodeUsingSerializer[T] = + copy(child = newChild) +} + +/** + * Initialize a Java Bean instance by setting its field values via setters. + */ +case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression]) + extends Expression with NonSQLExpression { + + override def nullable: Boolean = beanInstance.nullable + override def children: Seq[Expression] = beanInstance +: setters.values.toSeq + override def dataType: DataType = beanInstance.dataType + + private lazy val resolvedSetters = { + assert(beanInstance.dataType.isInstanceOf[ObjectType]) + + val ObjectType(beanClass) = beanInstance.dataType + setters.map { + case (name, expr) => + // Looking for known type mapping. + // But also looking for general `Object`-type parameter for generic methods. + val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object]) + val methods = paramTypes.flatMap { fieldClass => + try { + Some(beanClass.getDeclaredMethod(name, fieldClass)) + } catch { + case e: NoSuchMethodException => None + } + } + if (methods.isEmpty) { + throw QueryExecutionErrors.methodNotDeclaredError(name) + } + methods.head -> expr + } + } + + override def eval(input: InternalRow): Any = { + val instance = beanInstance.eval(input) + if (instance != null) { + val bean = instance.asInstanceOf[Object] + resolvedSetters.foreach { + case (setter, expr) => + val paramVal = expr.eval(input) + // We don't call setter if input value is null. + if (paramVal != null) { + setter.invoke(bean, paramVal.asInstanceOf[AnyRef]) + } + } + } + instance + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val instanceGen = beanInstance.genCode(ctx) + + val javaBeanInstance = ctx.freshName("javaBean") + val beanInstanceJavaType = CodeGenerator.javaType(beanInstance.dataType) + + val initialize = setters.map { + case (setterMethod, fieldValue) => + val fieldGen = fieldValue.genCode(ctx) + s""" + |${fieldGen.code} + |if (!${fieldGen.isNull}) { + | $javaBeanInstance.$setterMethod(${fieldGen.value}); + |} + """.stripMargin + } + val initializeCode = ctx.splitExpressionsWithCurrentInputs( + expressions = initialize.toSeq, + funcName = "initializeJavaBean", + extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil) + + val code = instanceGen.code + + code""" + |$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value}; + |if (!${instanceGen.isNull}) { + | $initializeCode + |} + """.stripMargin + ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value) + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): InitializeJavaBean = + super.legacyWithNewChildren(newChildren).asInstanceOf[InitializeJavaBean] +} + +/** + * Asserts that input values of a non-nullable child expression are not null. + * + * Note that there are cases where `child.nullable == true`, while we still need to add this + * assertion. Consider a nullable column `s` whose data type is a struct containing a non-nullable + * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all + * non-null `s`, `s.i` can't be null. + */ +case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) + extends UnaryExpression with NonSQLExpression { + + override def dataType: DataType = child.dataType + override def foldable: Boolean = false + override def nullable: Boolean = false + + final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK) + + override def flatArguments: Iterator[Any] = Iterator(child) + + private val errMsg = "Null value appeared in non-nullable field:" + + walkedTypePath.mkString("\n", "\n", "\n") + + "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + + "please try to use scala.Option[_] or other nullable types " + + "(e.g. java.lang.Integer instead of int/scala.Int)." + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (result == null) { + throw new NullPointerException(errMsg) + } + result + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + + // Use unnamed reference that doesn't create a local field here to reduce the number of fields + // because errMsgField is used only when the value is null. + val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + + val code = childGen.code + code""" + if (${childGen.isNull}) { + throw new NullPointerException($errMsgField); + } + """ + ev.copy(code = code, isNull = FalseLiteral, value = childGen.value) + } + + override protected def withNewChildInternal(newChild: Expression): AssertNotNull = + copy(child = newChild) +} + +/** + * Returns the value of field at index `index` from the external row `child`. + * This class can be viewed as [[GetStructField]] for [[Row]]s instead of [[InternalRow]]s. + * + * Note that the input row and the field we try to get are both guaranteed to be not null, if they + * are null, a runtime exception will be thrown. + */ +case class GetExternalRowField( + child: Expression, + index: Int, + fieldName: String) extends UnaryExpression with NonSQLExpression { + + override def nullable: Boolean = false + + override def dataType: DataType = ObjectType(classOf[Object]) + + private val errMsg = QueryExecutionErrors.fieldCannotBeNullMsg(index, fieldName) + + override def eval(input: InternalRow): Any = { + val inputRow = child.eval(input).asInstanceOf[Row] + if (inputRow == null) { + throw QueryExecutionErrors.inputExternalRowCannotBeNullError + } + if (inputRow.isNullAt(index)) { + throw QueryExecutionErrors.fieldCannotBeNullError(index, fieldName) + } + inputRow.get(index) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // Use unnamed reference that doesn't create a local field here to reduce the number of fields + // because errMsgField is used only when the field is null. + val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + val row = child.genCode(ctx) + val code = code""" + ${row.code} + + if (${row.isNull}) { + throw QueryExecutionErrors.inputExternalRowCannotBeNullError(); + } + + if (${row.value}.isNullAt($index)) { + throw new RuntimeException($errMsgField); + } + + final Object ${ev.value} = ${row.value}.get($index); + """ + ev.copy(code = code, isNull = FalseLiteral) + } + + override protected def withNewChildInternal(newChild: Expression): GetExternalRowField = + copy(child = newChild) +} + +/** + * Validates the actual data type of input expression at runtime. If it doesn't match the + * expectation, throw an exception. + */ +case class ValidateExternalType(child: Expression, expected: DataType) + extends UnaryExpression with NonSQLExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(ObjectType(classOf[Object])) + + override def nullable: Boolean = child.nullable + + override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected) + + private lazy val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" + + private lazy val checkType: (Any) => Boolean = expected match { + case _: DecimalType => + (value: Any) => { + value.isInstanceOf[java.math.BigDecimal] || value.isInstanceOf[scala.math.BigDecimal] || + value.isInstanceOf[Decimal] + } + case _: ArrayType => + (value: Any) => { + value.getClass.isArray || value.isInstanceOf[Seq[_]] + } + case _ => + val dataTypeClazz = ScalaReflection.javaBoxedType(dataType) + (value: Any) => { + dataTypeClazz.isInstance(value) + } + } + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (checkType(result)) { + result + } else { + throw new RuntimeException(s"${result.getClass.getName}$errMsg") + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // Use unnamed reference that doesn't create a local field here to reduce the number of fields + // because errMsgField is used only when the type doesn't match. + val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + val input = child.genCode(ctx) + val obj = input.value + + val typeCheck = expected match { + case _: DecimalType => + Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) + .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") + case _: ArrayType => + s"$obj.getClass().isArray() || $obj instanceof ${classOf[scala.collection.Seq[_]].getName}" + case _ => + s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" + } + + val code = code""" + ${input.code} + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + if (!${input.isNull}) { + if ($typeCheck) { + ${ev.value} = (${CodeGenerator.boxedType(dataType)}) $obj; + } else { + throw new RuntimeException($obj.getClass().getName() + $errMsgField); + } + } + + """ + ev.copy(code = code, isNull = input.isNull) + } + + override protected def withNewChildInternal(newChild: Expression): ValidateExternalType = + copy(child = newChild) +} diff --git a/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala new file mode 100644 index 000000000..ae320f8c6 --- /dev/null +++ b/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -0,0 +1,909 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.Locale +import java.util.regex.{Matcher, MatchResult, Pattern} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.commons.text.StringEscapeUtils + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + + +abstract class StringRegexExpression extends BinaryExpression + with ImplicitCastInputTypes with NullIntolerant { + + def escape(v: String): String + def matches(regex: Pattern, str: String): Boolean + + override def dataType: DataType = BooleanType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + // try cache foldable pattern + private lazy val cache: Pattern = right match { + case p: Expression if p.foldable => + compile(p.eval().asInstanceOf[UTF8String].toString) + case _ => null + } + + protected def compile(str: String): Pattern = if (str == null) { + null + } else { + // Let it raise exception if couldn't compile the regex string + Pattern.compile(escape(str)) + } + + protected def pattern(str: String) = if (cache == null) compile(str) else cache + + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val regex = pattern(input2.asInstanceOf[UTF8String].toString) + if(regex == null) { + null + } else { + matches(regex, input1.asInstanceOf[UTF8String].toString) + } + } +} + +// scalastyle:off line.contains.tab +/** + * Simple RegEx pattern matching function + */ +@ExpressionDescription( + usage = "str _FUNC_ pattern[ ESCAPE escape] - Returns true if str matches `pattern` with " + + "`escape`, null if any arguments are null, false otherwise.", + arguments = """ + Arguments: + * str - a string expression + * pattern - a string expression. The pattern is a string which is matched literally, with + exception to the following special symbols: + + _ matches any one character in the input (similar to . in posix regular expressions) + + % matches zero or more characters in the input (similar to .* in posix regular + expressions) + + Since Spark 2.0, string literals are unescaped in our SQL parser. For example, in order + to match "\abc", the pattern should be "\\abc". + + When SQL config 'spark.sql.parser.escapedStringLiterals' is enabled, it falls back + to Spark 1.6 behavior regarding string literal parsing. For example, if the config is + enabled, the pattern to match "\abc" should be "\abc". + * escape - an character added since Spark 3.0. The default escape character is the '\'. + If an escape character precedes a special symbol or another escape character, the + following character is matched literally. It is invalid to escape any other character. + """, + examples = """ + Examples: + > SELECT _FUNC_('Spark', '_park'); + true + > SET spark.sql.parser.escapedStringLiterals=true; + spark.sql.parser.escapedStringLiterals true + > SELECT '%SystemDrive%\Users\John' _FUNC_ '\%SystemDrive\%\\Users%'; + true + > SET spark.sql.parser.escapedStringLiterals=false; + spark.sql.parser.escapedStringLiterals false + > SELECT '%SystemDrive%\\Users\\John' _FUNC_ '\%SystemDrive\%\\\\Users%'; + true + > SELECT '%SystemDrive%/Users/John' _FUNC_ '/%SystemDrive/%//Users%' ESCAPE '/'; + true + """, + note = """ + Use RLIKE to match with standard regular expressions. + """, + since = "1.0.0", + group = "predicate_funcs") +// scalastyle:on line.contains.tab +case class Like(left: Expression, right: Expression, escapeChar: Char) + extends StringRegexExpression { + + def this(left: Expression, right: Expression) = this(left, right, '\\') + + override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar) + + override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() + + final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY) + + override def toString: String = escapeChar match { + case '\\' => s"$left LIKE $right" + case c => s"$left LIKE $right ESCAPE '$c'" + } + + override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}" + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val patternClass = classOf[Pattern].getName + val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) + val pattern = ctx.addMutableState(patternClass, "patternLike", + v => s"""$v = $patternClass.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. + val eval = left.genCode(ctx) + ev.copy(code = code""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); + } + """) + } else { + ev.copy(code = code""" + boolean ${ev.isNull} = true; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + """) + } + } else { + val pattern = ctx.freshName("pattern") + val rightStr = ctx.freshName("rightStr") + // We need to escape the escapeChar to make sure the generated code is valid. + // Otherwise we'll hit org.codehaus.commons.compiler.CompileException. + val escapedEscapeChar = StringEscapeUtils.escapeJava(escapeChar.toString) + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String $rightStr = $eval2.toString(); + $patternClass $pattern = $patternClass.compile( + $escapeFunc($rightStr, '$escapedEscapeChar')); + ${ev.value} = $pattern.matcher($eval1.toString()).matches(); + """ + }) + } + } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Like = + copy(left = newLeft, right = newRight) +} + +sealed abstract class MultiLikeBase + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { + + protected def patterns: Seq[UTF8String] + + protected def isNotSpecified: Boolean + + override def inputTypes: Seq[DataType] = StringType :: Nil + + override def dataType: DataType = BooleanType + + override def nullable: Boolean = true + + final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY) + + protected lazy val hasNull: Boolean = patterns.contains(null) + + protected lazy val cache = patterns.filterNot(_ == null) + .map(s => Pattern.compile(StringUtils.escapeLikeRegex(s.toString, '\\'))) + + protected lazy val matchFunc = if (isNotSpecified) { + (p: Pattern, inputValue: String) => !p.matcher(inputValue).matches() + } else { + (p: Pattern, inputValue: String) => p.matcher(inputValue).matches() + } + + protected def matches(exprValue: String): Any + + override def eval(input: InternalRow): Any = { + val exprValue = child.eval(input) + if (exprValue == null) { + null + } else { + matches(exprValue.toString) + } + } +} + +/** + * Optimized version of LIKE ALL, when all pattern values are literal. + */ +sealed abstract class LikeAllBase extends MultiLikeBase { + + override def matches(exprValue: String): Any = { + if (cache.forall(matchFunc(_, exprValue))) { + if (hasNull) null else true + } else { + false + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval = child.genCode(ctx) + val patternClass = classOf[Pattern].getName + val javaDataType = CodeGenerator.javaType(child.dataType) + val pattern = ctx.freshName("pattern") + val valueArg = ctx.freshName("valueArg") + val patternCache = ctx.addReferenceObj("patternCache", cache.asJava) + + val checkNotMatchCode = if (isNotSpecified) { + s"$pattern.matcher($valueArg.toString()).matches()" + } else { + s"!$pattern.matcher($valueArg.toString()).matches()" + } + + ev.copy(code = + code""" + |${eval.code} + |boolean ${ev.isNull} = false; + |boolean ${ev.value} = true; + |if (${eval.isNull}) { + | ${ev.isNull} = true; + |} else { + | $javaDataType $valueArg = ${eval.value}; + | for ($patternClass $pattern: $patternCache) { + | if ($checkNotMatchCode) { + | ${ev.value} = false; + | break; + | } + | } + | if (${ev.value} && $hasNull) ${ev.isNull} = true; + |} + """.stripMargin) + } +} + +case class LikeAll(child: Expression, patterns: Seq[UTF8String]) extends LikeAllBase { + override def isNotSpecified: Boolean = false + override protected def withNewChildInternal(newChild: Expression): LikeAll = + copy(child = newChild) +} + +case class NotLikeAll(child: Expression, patterns: Seq[UTF8String]) extends LikeAllBase { + override def isNotSpecified: Boolean = true + override protected def withNewChildInternal(newChild: Expression): NotLikeAll = + copy(child = newChild) +} + +/** + * Optimized version of LIKE ANY, when all pattern values are literal. + */ +sealed abstract class LikeAnyBase extends MultiLikeBase { + + override def matches(exprValue: String): Any = { + if (cache.exists(matchFunc(_, exprValue))) { + true + } else { + if (hasNull) null else false + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval = child.genCode(ctx) + val patternClass = classOf[Pattern].getName + val javaDataType = CodeGenerator.javaType(child.dataType) + val pattern = ctx.freshName("pattern") + val valueArg = ctx.freshName("valueArg") + val patternCache = ctx.addReferenceObj("patternCache", cache.asJava) + + val checkMatchCode = if (isNotSpecified) { + s"!$pattern.matcher($valueArg.toString()).matches()" + } else { + s"$pattern.matcher($valueArg.toString()).matches()" + } + + ev.copy(code = + code""" + |${eval.code} + |boolean ${ev.isNull} = false; + |boolean ${ev.value} = false; + |if (${eval.isNull}) { + | ${ev.isNull} = true; + |} else { + | $javaDataType $valueArg = ${eval.value}; + | for ($patternClass $pattern: $patternCache) { + | if ($checkMatchCode) { + | ${ev.value} = true; + | break; + | } + | } + | if (!${ev.value} && $hasNull) ${ev.isNull} = true; + |} + """.stripMargin) + } +} + +case class LikeAny(child: Expression, patterns: Seq[UTF8String]) extends LikeAnyBase { + override def isNotSpecified: Boolean = false + override protected def withNewChildInternal(newChild: Expression): LikeAny = + copy(child = newChild) +} + +case class NotLikeAny(child: Expression, patterns: Seq[UTF8String]) extends LikeAnyBase { + override def isNotSpecified: Boolean = true + override protected def withNewChildInternal(newChild: Expression): NotLikeAny = + copy(child = newChild) +} + +// scalastyle:off line.contains.tab +@ExpressionDescription( + usage = "_FUNC_(str, regexp) - Returns true if `str` matches `regexp`, or false otherwise.", + arguments = """ + Arguments: + * str - a string expression + * regexp - a string expression. The regex string should be a Java regular expression. + + Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL + parser. For example, to match "\abc", a regular expression for `regexp` can be + "^\\abc$". + + There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to + fallback to the Spark 1.6 behavior regarding string literal parsing. For example, + if the config is enabled, the `regexp` that can match "\abc" is "^\abc$". + """, + examples = """ + Examples: + > SET spark.sql.parser.escapedStringLiterals=true; + spark.sql.parser.escapedStringLiterals true + > SELECT _FUNC_('%SystemDrive%\Users\John', '%SystemDrive%\\Users.*'); + true + > SET spark.sql.parser.escapedStringLiterals=false; + spark.sql.parser.escapedStringLiterals false + > SELECT _FUNC_('%SystemDrive%\\Users\\John', '%SystemDrive%\\\\Users.*'); + true + """, + note = """ + Use LIKE to match with simple string pattern. + """, + since = "1.0.0", + group = "predicate_funcs") +// scalastyle:on line.contains.tab +case class RLike(left: Expression, right: Expression) extends StringRegexExpression { + + override def escape(v: String): String = v + override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) + override def toString: String = s"RLIKE($left, $right)" + override def sql: String = s"${prettyName.toUpperCase(Locale.ROOT)}(${left.sql}, ${right.sql})" + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val patternClass = classOf[Pattern].getName + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) + val pattern = ctx.addMutableState(patternClass, "patternRLike", + v => s"""$v = $patternClass.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. + val eval = left.genCode(ctx) + ev.copy(code = code""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0); + } + """) + } else { + ev.copy(code = code""" + boolean ${ev.isNull} = true; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + """) + } + } else { + val rightStr = ctx.freshName("rightStr") + val pattern = ctx.freshName("pattern") + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String $rightStr = $eval2.toString(); + $patternClass $pattern = $patternClass.compile($rightStr); + ${ev.value} = $pattern.matcher($eval1.toString()).find(0); + """ + }) + } + } + + override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): RLike = + copy(left = newLeft, right = newRight) +} + + +/** + * Splits str around matches of the given regex. + */ +@ExpressionDescription( + usage = "_FUNC_(str, regex, limit) - Splits `str` around occurrences that match `regex`" + + " and returns an array with a length of at most `limit`", + arguments = """ + Arguments: + * str - a string expression to split. + * regex - a string representing a regular expression. The regex string should be a + Java regular expression. + * limit - an integer expression which controls the number of times the regex is applied. + * limit > 0: The resulting array's length will not be more than `limit`, + and the resulting array's last entry will contain all input + beyond the last matched regex. + * limit <= 0: `regex` will be applied as many times as possible, and + the resulting array can be of any size. + """, + examples = """ + Examples: + > SELECT _FUNC_('oneAtwoBthreeC', '[ABC]'); + ["one","two","three",""] + > SELECT _FUNC_('oneAtwoBthreeC', '[ABC]', -1); + ["one","two","three",""] + > SELECT _FUNC_('oneAtwoBthreeC', '[ABC]', 2); + ["one","twoBthreeC"] + """, + since = "1.5.0", + group = "string_funcs") +case class StringSplit(str: Expression, regex: Expression, limit: Expression) + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { + + override def dataType: DataType = ArrayType(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + override def first: Expression = str + override def second: Expression = regex + override def third: Expression = limit + + def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1)); + + override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = { + val strings = string.asInstanceOf[UTF8String].split( + regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int]) + new GenericArrayData(strings.asInstanceOf[Array[Any]]) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val arrayClass = classOf[GenericArrayData].getName + nullSafeCodeGen(ctx, ev, (str, regex, limit) => { + // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. + s"""${ev.value} = new $arrayClass($str.split($regex,$limit));""".stripMargin + }) + } + + override def prettyName: String = "split" + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): StringSplit = + copy(str = newFirst, regex = newSecond, limit = newThird) +} + + +/** + * Replace all substrings of str that match regexp with rep. + * + * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str, regexp, rep[, position]) - Replaces all substrings of `str` that match `regexp` with `rep`.", + arguments = """ + Arguments: + * str - a string expression to search for a regular expression pattern match. + * regexp - a string representing a regular expression. The regex string should be a + Java regular expression. + + Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL + parser. For example, to match "\abc", a regular expression for `regexp` can be + "^\\abc$". + + There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to + fallback to the Spark 1.6 behavior regarding string literal parsing. For example, + if the config is enabled, the `regexp` that can match "\abc" is "^\abc$". + * rep - a string expression to replace matched substrings. + * position - a positive integer literal that indicates the position within `str` to begin searching. + The default is 1. If position is greater than the number of characters in `str`, the result is `str`. + """, + examples = """ + Examples: + > SELECT _FUNC_('100-200', '(\\d+)', 'num'); + num-num + """, + since = "1.5.0", + group = "string_funcs") +// scalastyle:on line.size.limit +case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression, pos: Expression) + extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant { + + def this(subject: Expression, regexp: Expression, rep: Expression) = + this(subject, regexp, rep, Literal(1)) + + override def checkInputDataTypes(): TypeCheckResult = { + if (!pos.foldable) { + return TypeCheckFailure(s"Position expression must be foldable, but got $pos") + } + + val posEval = pos.eval() + if (posEval == null || posEval.asInstanceOf[Int] > 0) { + TypeCheckSuccess + } else { + TypeCheckFailure(s"Position expression must be positive, but got: $posEval") + } + } + + // last regex in string, we will update the pattern iff regexp value changed. + @transient private var lastRegex: UTF8String = _ + // last regex pattern, we cache it for performance concern + @transient private var pattern: Pattern = _ + // last replacement string, we don't want to convert a UTF8String => java.langString every time. + @transient private var lastReplacement: String = _ + @transient private var lastReplacementInUTF8: UTF8String = _ + // result buffer write by Matcher + @transient private lazy val result: StringBuffer = new StringBuffer + final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_REPLACE) + + override def nullSafeEval(s: Any, p: Any, r: Any, i: Any): Any = { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String].clone() + pattern = Pattern.compile(lastRegex.toString) + } + if (!r.equals(lastReplacementInUTF8)) { + // replacement string changed + lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone() + lastReplacement = lastReplacementInUTF8.toString + } + val source = s.toString() + val position = i.asInstanceOf[Int] - 1 + if (position < source.length) { + val m = pattern.matcher(source) + m.region(position, source.length) + result.delete(0, result.length()) + while (m.find) { + m.appendReplacement(result, lastReplacement) + } + m.appendTail(result) + UTF8String.fromString(result.toString) + } else { + s + } + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = + Seq(StringType, StringType, StringType, IntegerType) + override def prettyName: String = "regexp_replace" + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val termResult = ctx.freshName("termResult") + + val classNamePattern = classOf[Pattern].getCanonicalName + val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName + + val matcher = ctx.freshName("matcher") + val source = ctx.freshName("source") + val position = ctx.freshName("position") + + val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") + val termPattern = ctx.addMutableState(classNamePattern, "pattern") + val termLastReplacement = ctx.addMutableState("String", "lastReplacement") + val termLastReplacementInUTF8 = ctx.addMutableState("UTF8String", "lastReplacementInUTF8") + + val setEvNotNull = if (nullable) { + s"${ev.isNull} = false;" + } else { + "" + } + + nullSafeCodeGen(ctx, ev, (subject, regexp, rep, pos) => { + s""" + if (!$regexp.equals($termLastRegex)) { + // regex value changed + $termLastRegex = $regexp.clone(); + $termPattern = $classNamePattern.compile($termLastRegex.toString()); + } + if (!$rep.equals($termLastReplacementInUTF8)) { + // replacement string changed + $termLastReplacementInUTF8 = $rep.clone(); + $termLastReplacement = $termLastReplacementInUTF8.toString(); + } + String $source = $subject.toString(); + int $position = $pos - 1; + if ($position < $source.length()) { + $classNameStringBuffer $termResult = new $classNameStringBuffer(); + java.util.regex.Matcher $matcher = $termPattern.matcher($source); + $matcher.region($position, $source.length()); + + while ($matcher.find()) { + $matcher.appendReplacement($termResult, $termLastReplacement); + } + $matcher.appendTail($termResult); + ${ev.value} = UTF8String.fromString($termResult.toString()); + $termResult = null; + } else { + ${ev.value} = $subject; + } + $setEvNotNull + """ + }) + } + + override def first: Expression = subject + override def second: Expression = regexp + override def third: Expression = rep + override def fourth: Expression = pos + + override protected def withNewChildrenInternal( + first: Expression, second: Expression, third: Expression, fourth: Expression): RegExpReplace = + copy(subject = first, regexp = second, rep = third, pos = fourth) +} + +object RegExpReplace { + def apply(subject: Expression, regexp: Expression, rep: Expression): RegExpReplace = + new RegExpReplace(subject, regexp, rep) +} + +object RegExpExtractBase { + def checkGroupIndex(groupCount: Int, groupIndex: Int): Unit = { + if (groupIndex < 0) { + throw QueryExecutionErrors.regexGroupIndexLessThanZeroError + } else if (groupCount < groupIndex) { + throw QueryExecutionErrors.regexGroupIndexExceedGroupCountError( + groupCount, groupIndex) + } + } +} + +abstract class RegExpExtractBase + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { + def subject: Expression + def regexp: Expression + def idx: Expression + + // last regex in string, we will update the pattern iff regexp value changed. + @transient private var lastRegex: UTF8String = _ + // last regex pattern, we cache it for performance concern + @transient private var pattern: Pattern = _ + + final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_EXTRACT_FAMILY) + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) + override def first: Expression = subject + override def second: Expression = regexp + override def third: Expression = idx + + protected def getLastMatcher(s: Any, p: Any): Matcher = { + if (p != lastRegex) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String].clone() + pattern = Pattern.compile(lastRegex.toString) + } + pattern.matcher(s.toString) + } +} + +/** + * Extract a specific(idx) group identified by a Java regex. + * + * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. + */ +@ExpressionDescription( + usage = """ + _FUNC_(str, regexp[, idx]) - Extract the first string in the `str` that match the `regexp` + expression and corresponding to the regex group index. + """, + arguments = """ + Arguments: + * str - a string expression. + * regexp - a string representing a regular expression. The regex string should be a + Java regular expression. + + Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL + parser. For example, to match "\abc", a regular expression for `regexp` can be + "^\\abc$". + + There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to + fallback to the Spark 1.6 behavior regarding string literal parsing. For example, + if the config is enabled, the `regexp` that can match "\abc" is "^\abc$". + * idx - an integer expression that representing the group index. The regex maybe contains + multiple groups. `idx` indicates which regex group to extract. The group index should + be non-negative. The minimum value of `idx` is 0, which means matching the entire + regular expression. If `idx` is not specified, the default group index value is 1. The + `idx` parameter is the Java regex Matcher group() method index. + """, + examples = """ + Examples: + > SELECT _FUNC_('100-200', '(\\d+)-(\\d+)', 1); + 100 + """, + since = "1.5.0", + group = "string_funcs") +case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) + extends RegExpExtractBase { + def this(s: Expression, r: Expression) = this(s, r, Literal(1)) + + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + val m = getLastMatcher(s, p) + if (m.find) { + val mr: MatchResult = m.toMatchResult + val index = r.asInstanceOf[Int] + RegExpExtractBase.checkGroupIndex(mr.groupCount, index) + val group = mr.group(index) + if (group == null) { // Pattern matched, but it's an optional group + UTF8String.EMPTY_UTF8 + } else { + UTF8String.fromString(group) + } + } else { + UTF8String.EMPTY_UTF8 + } + } + + override def dataType: DataType = StringType + override def prettyName: String = "regexp_extract" + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val classNamePattern = classOf[Pattern].getCanonicalName + val classNameRegExpExtractBase = classOf[RegExpExtractBase].getCanonicalName + val matcher = ctx.freshName("matcher") + val matchResult = ctx.freshName("matchResult") + + val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") + val termPattern = ctx.addMutableState(classNamePattern, "pattern") + + val setEvNotNull = if (nullable) { + s"${ev.isNull} = false;" + } else { + "" + } + + nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { + s""" + if (!$regexp.equals($termLastRegex)) { + // regex value changed + $termLastRegex = $regexp.clone(); + $termPattern = $classNamePattern.compile($termLastRegex.toString()); + } + java.util.regex.Matcher $matcher = + $termPattern.matcher($subject.toString()); + if ($matcher.find()) { + java.util.regex.MatchResult $matchResult = $matcher.toMatchResult(); + $classNameRegExpExtractBase.checkGroupIndex($matchResult.groupCount(), $idx); + if ($matchResult.group($idx) == null) { + ${ev.value} = UTF8String.EMPTY_UTF8; + } else { + ${ev.value} = UTF8String.fromString($matchResult.group($idx)); + } + $setEvNotNull + } else { + ${ev.value} = UTF8String.EMPTY_UTF8; + $setEvNotNull + }""" + }) + } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): RegExpExtract = + copy(subject = newFirst, regexp = newSecond, idx = newThird) +} + +/** + * Extract all specific(idx) groups identified by a Java regex. + * + * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. + */ +@ExpressionDescription( + usage = """ + _FUNC_(str, regexp[, idx]) - Extract all strings in the `str` that match the `regexp` + expression and corresponding to the regex group index. + """, + arguments = """ + Arguments: + * str - a string expression. + * regexp - a string representing a regular expression. The regex string should be a + Java regular expression. + + Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL + parser. For example, to match "\abc", a regular expression for `regexp` can be + "^\\abc$". + + There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to + fallback to the Spark 1.6 behavior regarding string literal parsing. For example, + if the config is enabled, the `regexp` that can match "\abc" is "^\abc$". + * idx - an integer expression that representing the group index. The regex may contains + multiple groups. `idx` indicates which regex group to extract. The group index should + be non-negative. The minimum value of `idx` is 0, which means matching the entire + regular expression. If `idx` is not specified, the default group index value is 1. The + `idx` parameter is the Java regex Matcher group() method index. + """, + examples = """ + Examples: + > SELECT _FUNC_('100-200, 300-400', '(\\d+)-(\\d+)', 1); + ["100","300"] + """, + since = "3.1.0", + group = "string_funcs") +case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expression) + extends RegExpExtractBase { + def this(s: Expression, r: Expression) = this(s, r, Literal(1)) + + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + val m = getLastMatcher(s, p) + val matchResults = new ArrayBuffer[UTF8String]() + while(m.find) { + val mr: MatchResult = m.toMatchResult + val index = r.asInstanceOf[Int] + RegExpExtractBase.checkGroupIndex(mr.groupCount, index) + val group = mr.group(index) + if (group == null) { // Pattern matched, but it's an optional group + matchResults += UTF8String.EMPTY_UTF8 + } else { + matchResults += UTF8String.fromString(group) + } + } + + new GenericArrayData(matchResults.toArray.asInstanceOf[Array[Any]]) + } + + override def dataType: DataType = ArrayType(StringType) + override def prettyName: String = "regexp_extract_all" + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val classNamePattern = classOf[Pattern].getCanonicalName + val classNameRegExpExtractBase = classOf[RegExpExtractBase].getCanonicalName + val arrayClass = classOf[GenericArrayData].getName + val matcher = ctx.freshName("matcher") + val matchResult = ctx.freshName("matchResult") + val matchResults = ctx.freshName("matchResults") + + val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex") + val termPattern = ctx.addMutableState(classNamePattern, "pattern") + + val setEvNotNull = if (nullable) { + s"${ev.isNull} = false;" + } else { + "" + } + nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { + s""" + | if (!$regexp.equals($termLastRegex)) { + | // regex value changed + | $termLastRegex = $regexp.clone(); + | $termPattern = $classNamePattern.compile($termLastRegex.toString()); + | } + | java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString()); + | java.util.ArrayList $matchResults = new java.util.ArrayList(); + | while ($matcher.find()) { + | java.util.regex.MatchResult $matchResult = $matcher.toMatchResult(); + | $classNameRegExpExtractBase.checkGroupIndex($matchResult.groupCount(), $idx); + | if ($matchResult.group($idx) == null) { + | $matchResults.add(UTF8String.EMPTY_UTF8); + | } else { + | $matchResults.add(UTF8String.fromString($matchResult.group($idx))); + | } + | } + | ${ev.value} = + | new $arrayClass($matchResults.toArray(new UTF8String[$matchResults.size()])); + | $setEvNotNull + """ + }) + } + + override protected def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): RegExpExtractAll = + copy(subject = newFirst, regexp = newSecond, idx = newThird) +} diff --git a/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala new file mode 100644 index 000000000..8063d7f1c --- /dev/null +++ b/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, Complete} +import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, PhysicalOperation} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +/** + * Insert a filter on one side of the join if the other side has a selective predicate. + * The filter could be an IN subquery (converted to a semi join), a bloom filter, or something + * else in the future. + */ +object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper { + + // Wraps `expr` with a hash function if its byte size is larger than an integer. + private def mayWrapWithHash(expr: Expression): Expression = { + if (expr.dataType.defaultSize > IntegerType.defaultSize) { + new Murmur3Hash(Seq(expr)) + } else { + expr + } + } + + private def injectFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan + ): LogicalPlan = { + require(conf.runtimeFilterBloomFilterEnabled || conf.runtimeFilterSemiJoinReductionEnabled) + if (conf.runtimeFilterBloomFilterEnabled) { + injectBloomFilter( + filterApplicationSideExp, + filterApplicationSidePlan, + filterCreationSideExp, + filterCreationSidePlan + ) + } else { + injectInSubqueryFilter( + filterApplicationSideExp, + filterApplicationSidePlan, + filterCreationSideExp, + filterCreationSidePlan + ) + } + } + + private def injectBloomFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan + ): LogicalPlan = { + // Skip if the filter creation side is too big + if (filterCreationSidePlan.stats.sizeInBytes > conf.runtimeFilterBloomFilterThreshold) { + return filterApplicationSidePlan + } + val rowCount = filterCreationSidePlan.stats.rowCount + val bloomFilterAgg = + if (rowCount.isDefined && rowCount.get.longValue > 0L) { + new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)), + Literal(rowCount.get.longValue)) + } else { + new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp))) + } + val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None) + val alias = Alias(aggExp, "bloomFilter")() + val aggregate = ConstantFolding(Aggregate(Nil, Seq(alias), filterCreationSidePlan)) + val bloomFilterSubquery = ScalarSubquery(aggregate, Nil) + val filter = BloomFilterMightContain(bloomFilterSubquery, + new XxHash64(Seq(filterApplicationSideExp))) + Filter(filter, filterApplicationSidePlan) + } + + private def injectInSubqueryFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan + ): LogicalPlan = { + require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType) + val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp) + val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)() + val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan) + if (!canBroadcastBySize(aggregate, conf)) { + // Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold, + // i.e., the semi-join will be a shuffled join, which is not worthwhile. + return filterApplicationSidePlan + } + val filter = InSubquery(Seq(mayWrapWithHash(filterApplicationSideExp)), + ListQuery(aggregate, childOutputs = aggregate.output)) + Filter(filter, filterApplicationSidePlan) + } + + /** + * Returns whether the plan is a simple filter over scan and the filter is likely selective + * Also check if the plan only has simple expressions (attribute reference, literals) so that we + * do not add a subquery that might have an expensive computation + */ + private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = { + plan.expressions + val ret = plan match { + case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] => + filters.forall(isSimpleExpression) && + filters.exists(isLikelySelective) + case _ => false + } + !plan.isStreaming && ret + } + + private def isSimpleExpression(e: Expression): Boolean = { + !e.containsAnyPattern(PYTHON_UDF, SCALA_UDF, INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, + REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE) + } + + /** + * Returns whether an expression is likely to be selective + */ + private def isLikelySelective(e: Expression): Boolean = e match { + case Not(expr) => isLikelySelective(expr) + case And(l, r) => isLikelySelective(l) || isLikelySelective(r) + case Or(l, r) => isLikelySelective(l) && isLikelySelective(r) + case _: StringRegexExpression => true + case _: BinaryComparison => true + case _: In | _: InSet => true + case _: StringPredicate => true + case _: MultiLikeBase => true + case _ => false + } + + private def canFilterLeft(joinType: JoinType): Boolean = joinType match { + case Inner | RightOuter => true + case _ => false + } + + private def canFilterRight(joinType: JoinType): Boolean = joinType match { + case Inner | LeftOuter => true + case _ => false + } + + private def isProbablyShuffleJoin(left: LogicalPlan, + right: LogicalPlan, hint: JoinHint): Boolean = { + !hintToBroadcastLeft(hint) && !hintToBroadcastRight(hint) && + !canBroadcastBySize(left, conf) && !canBroadcastBySize(right, conf) + } + + private def probablyHasShuffle(plan: LogicalPlan): Boolean = { + plan.collect { + case j@Join(left, right, _, _, hint) + if !hintToBroadcastLeft(hint) && !hintToBroadcastRight(hint) && + !canBroadcastBySize(left, conf) && !canBroadcastBySize(right, conf) => j + case a: Aggregate => a + }.nonEmpty + } + + // Returns the max scan byte size in the subtree rooted at `filterApplicationSide`. + private def maxScanByteSize(filterApplicationSide: LogicalPlan): BigInt = { + val defaultSizeInBytes = conf.getConf(SQLConf.DEFAULT_SIZE_IN_BYTES) + filterApplicationSide.collect({ + case leaf: LeafNode => leaf + }).map(scan => { + // DEFAULT_SIZE_IN_BYTES means there's no byte size information in stats. Since we avoid + // creating a Bloom filter when the filter application side is very small, so using 0 + // as the byte size when the actual size is unknown can avoid regression by applying BF + // on a small table. + if (scan.stats.sizeInBytes == defaultSizeInBytes) BigInt(0) else scan.stats.sizeInBytes + }).max + } + + // Returns true if `filterApplicationSide` satisfies the byte size requirement to apply a + // Bloom filter; false otherwise. + private def satisfyByteSizeRequirement(filterApplicationSide: LogicalPlan): Boolean = { + // In case `filterApplicationSide` is a union of many small tables, disseminating the Bloom + // filter to each small task might be more costly than scanning them itself. Thus, we use max + // rather than sum here. + val maxScanSize = maxScanByteSize(filterApplicationSide) + maxScanSize >= + conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD) + } + + private def filteringHasBenefit( + filterApplicationSide: LogicalPlan, + filterCreationSide: LogicalPlan, + filterApplicationSideExp: Expression, + hint: JoinHint): Boolean = { + // Check that: + // 1. The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the + // expression references originate from a single leaf node) + // 2. The filter creation side has a selective predicate + // 3. The current join is a shuffle join or a broadcast join that has a shuffle or aggregate + // in the filter application side + // 4. The filterApplicationSide is larger than the filterCreationSide by a configurable + // threshold + findExpressionAndTrackLineageDown(filterApplicationSideExp, + filterApplicationSide).isDefined && isSelectiveFilterOverScan(filterCreationSide) && + (isProbablyShuffleJoin(filterApplicationSide, filterCreationSide, hint) || + probablyHasShuffle(filterApplicationSide)) && + satisfyByteSizeRequirement(filterApplicationSide) + } + + def hasRuntimeFilter(left: LogicalPlan, right: LogicalPlan, leftKey: Expression, + rightKey: Expression): Boolean = { + if (conf.runtimeFilterBloomFilterEnabled) { + hasBloomFilter(left, right, leftKey, rightKey) + } else { + hasInSubquery(left, right, leftKey, rightKey) + } + } + + // This checks if there is already a DPP filter, as this rule is called just after DPP. + def hasDynamicPruningSubquery(left: LogicalPlan, right: LogicalPlan, leftKey: Expression, + rightKey: Expression): Boolean = { + (left, right) match { + case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan), _) => + pruningKey.fastEquals(leftKey) || hasDynamicPruningSubquery(plan, right, leftKey, rightKey) + case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan)) => + pruningKey.fastEquals(rightKey) || + hasDynamicPruningSubquery(left, plan, leftKey, rightKey) + case _ => false + } + } + + def hasBloomFilter(left: LogicalPlan, right: LogicalPlan, leftKey: Expression, + rightKey: Expression): Boolean = { + findBloomFilterWithExp(left, leftKey) || findBloomFilterWithExp(right, rightKey) + } + + private def findBloomFilterWithExp(plan: LogicalPlan, key: Expression): Boolean = { + plan.find { + case Filter(condition, _) => + splitConjunctivePredicates(condition).exists { + case BloomFilterMightContain(_, XxHash64(Seq(valueExpression), _)) + if valueExpression.fastEquals(key) => true + case _ => false + } + case _ => false + }.isDefined + } + + def hasInSubquery(left: LogicalPlan, right: LogicalPlan, leftKey: Expression, + rightKey: Expression): Boolean = { + (left, right) match { + case (Filter(InSubquery(Seq(key), + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _), _) => + key.fastEquals(leftKey) || key.fastEquals(new Murmur3Hash(Seq(leftKey))) + case (_, Filter(InSubquery(Seq(key), + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _)) => + key.fastEquals(rightKey) || key.fastEquals(new Murmur3Hash(Seq(rightKey))) + case _ => false + } + } + + private def tryInjectRuntimeFilter(plan: LogicalPlan): LogicalPlan = { + var filterCounter = 0 + val numFilterThreshold = conf.getConf(SQLConf.RUNTIME_FILTER_NUMBER_THRESHOLD) + plan transformUp { + case join @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, left, right, hint) => + var newLeft = left + var newRight = right + (leftKeys, rightKeys).zipped.foreach((l, r) => { + // Check if: + // 1. There is already a DPP filter on the key + // 2. There is already a runtime filter (Bloom filter or IN subquery) on the key + // 3. The keys are simple cheap expressions + if (filterCounter < numFilterThreshold && + !hasDynamicPruningSubquery(left, right, l, r) && + !hasRuntimeFilter(newLeft, newRight, l, r) && + isSimpleExpression(l) && isSimpleExpression(r)) { + if (canFilterLeft(joinType) && filteringHasBenefit(left, right, l, hint)) { + newLeft = injectFilter(l, newLeft, r, right) + filterCounter = filterCounter + 1 + } else if (canFilterRight(joinType) && filteringHasBenefit(right, left, r, hint)) { + newRight = injectFilter(r, newRight, l, left) + filterCounter = filterCounter + 1 + } + } + }) + Join(newLeft, newRight, joinType, join.condition, hint) + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan match { + case s: Subquery if s.correlated => plan + case _ if !conf.runtimeFilterSemiJoinReductionEnabled && + !conf.runtimeFilterBloomFilterEnabled => plan + case _ => tryInjectRuntimeFilter(plan) + } + +} diff --git a/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala new file mode 100644 index 000000000..c3a0a90d8 --- /dev/null +++ b/shims/spark321/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.trees + +// Enums for commonly encountered tree patterns in rewrite rules. +object TreePattern extends Enumeration { + type TreePattern = Value + + // Enum Ids start from 0. + // Expression patterns (alphabetically ordered) + val AGGREGATE_EXPRESSION = Value(0) + val ALIAS: Value = Value + val AND_OR: Value = Value + val ARRAYS_ZIP: Value = Value + val ATTRIBUTE_REFERENCE: Value = Value + val APPEND_COLUMNS: Value = Value + val AVERAGE: Value = Value + val GROUPING_ANALYTICS: Value = Value + val BINARY_ARITHMETIC: Value = Value + val BINARY_COMPARISON: Value = Value + val BOOL_AGG: Value = Value + val CASE_WHEN: Value = Value + val CAST: Value = Value + val CONCAT: Value = Value + val COUNT: Value = Value + val COUNT_IF: Value = Value + val CREATE_NAMED_STRUCT: Value = Value + val CURRENT_LIKE: Value = Value + val DESERIALIZE_TO_OBJECT: Value = Value + val DYNAMIC_PRUNING_EXPRESSION: Value = Value + val DYNAMIC_PRUNING_SUBQUERY: Value = Value + val EXISTS_SUBQUERY = Value + val EXPRESSION_WITH_RANDOM_SEED: Value = Value + val EXTRACT_VALUE: Value = Value + val GENERATE: Value = Value + val GENERATOR: Value = Value + val HIGH_ORDER_FUNCTION: Value = Value + val IF: Value = Value + val IN: Value = Value + val IN_SUBQUERY: Value = Value + val INSET: Value = Value + val INTERSECT: Value = Value + val INVOKE: Value = Value + val JSON_TO_STRUCT: Value = Value + val LAMBDA_FUNCTION: Value = Value + val LAMBDA_VARIABLE: Value = Value + val LATERAL_SUBQUERY: Value = Value + val LIKE_FAMLIY: Value = Value + val LIST_SUBQUERY: Value = Value + val LITERAL: Value = Value + val MAP_OBJECTS: Value = Value + val MULTI_ALIAS: Value = Value + val NEW_INSTANCE: Value = Value + val NOT: Value = Value + val NULL_CHECK: Value = Value + val NULL_LITERAL: Value = Value + val SERIALIZE_FROM_OBJECT: Value = Value + val OUTER_REFERENCE: Value = Value + val PIVOT: Value = Value + val PLAN_EXPRESSION: Value = Value + val PYTHON_UDF: Value = Value + val REGEXP_EXTRACT_FAMILY: Value = Value + val REGEXP_REPLACE: Value = Value + val RUNTIME_REPLACEABLE: Value = Value + val SCALAR_SUBQUERY: Value = Value + val SCALA_UDF: Value = Value + val SORT: Value = Value + val SUBQUERY_ALIAS: Value = Value + val SUM: Value = Value + val TIME_WINDOW: Value = Value + val TIME_ZONE_AWARE_EXPRESSION: Value = Value + val TRUE_OR_FALSE_LITERAL: Value = Value + val WINDOW_EXPRESSION: Value = Value + val UNARY_POSITIVE: Value = Value + val UPDATE_FIELDS: Value = Value + val UPPER_OR_LOWER: Value = Value + val UP_CAST: Value = Value + + // Logical plan patterns (alphabetically ordered) + val AGGREGATE: Value = Value + val COMMAND: Value = Value + val CTE: Value = Value + val DISTINCT_LIKE: Value = Value + val EVENT_TIME_WATERMARK: Value = Value + val EXCEPT: Value = Value + val FILTER: Value = Value + val INNER_LIKE_JOIN: Value = Value + val JOIN: Value = Value + val LATERAL_JOIN: Value = Value + val LEFT_SEMI_OR_ANTI_JOIN: Value = Value + val LIMIT: Value = Value + val LOCAL_RELATION: Value = Value + val LOGICAL_QUERY_STAGE: Value = Value + val NATURAL_LIKE_JOIN: Value = Value + val OUTER_JOIN: Value = Value + val PROJECT: Value = Value + val REPARTITION_OPERATION: Value = Value + val UNION: Value = Value + val UNRESOLVED_RELATION: Value = Value + val TYPED_FILTER: Value = Value + val WINDOW: Value = Value + val WITH_WINDOW_DEFINITION: Value = Value + + // Unresolved expression patterns (Alphabetically ordered) + val UNRESOLVED_ALIAS: Value = Value + val UNRESOLVED_ATTRIBUTE: Value = Value + val UNRESOLVED_DESERIALIZER: Value = Value + val UNRESOLVED_ORDINAL: Value = Value + val UNRESOLVED_FUNCTION: Value = Value + val UNRESOLVED_HINT: Value = Value + val UNRESOLVED_WINDOW_EXPRESSION: Value = Value + + // Unresolved Plan patterns (Alphabetically ordered) + val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value + val UNRESOLVED_FUNC: Value = Value + + // Execution expression patterns (alphabetically ordered) + val IN_SUBQUERY_EXEC: Value = Value + + // Execution Plan patterns (alphabetically ordered) + val EXCHANGE: Value = Value +} diff --git a/shims/spark321/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/shims/spark321/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala new file mode 100644 index 000000000..da11df664 --- /dev/null +++ b/shims/spark321/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.ExperimentalMethods +import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.optimizer._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions +import org.apache.spark.sql.execution.datasources.SchemaPruning +import org.apache.spark.sql.execution.datasources.v2.{V2ScanRelationPushDown, V2Writes} +import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning} +import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs} + +class SparkOptimizer( + catalogManager: CatalogManager, + catalog: SessionCatalog, + experimentalMethods: ExperimentalMethods) + extends Optimizer(catalogManager) { + + override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = + // TODO: move SchemaPruning into catalyst + SchemaPruning :: V2ScanRelationPushDown :: V2Writes :: PruneFileSourcePartitions :: Nil + + override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ + Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ + Batch("PartitionPruning", Once, + PartitionPruning) :+ + Batch("InjectRuntimeFilter", FixedPoint(1), + InjectRuntimeFilter, + RewritePredicateSubquery) :+ + Batch("Pushdown Filters from PartitionPruning", fixedPoint, + PushDownPredicates) :+ + Batch("Cleanup filters that cannot be pushed down", Once, + CleanupDynamicPruningFilters, + PruneFilters)) ++ + postHocOptimizationBatches :+ + Batch("Extract Python UDFs", Once, + ExtractPythonUDFFromJoinCondition, + // `ExtractPythonUDFFromJoinCondition` can convert a join to a cartesian product. + // Here, we rerun cartesian product check. + CheckCartesianProducts, + ExtractPythonUDFFromAggregate, + // This must be executed after `ExtractPythonUDFFromAggregate` and before `ExtractPythonUDFs`. + ExtractGroupingPythonUDFFromAggregate, + ExtractPythonUDFs, + // The eval-python node may be between Project/Filter and the scan node, which breaks + // column pruning and filter push-down. Here we rerun the related optimizer rules. + ColumnPruning, + PushPredicateThroughNonJoin, + RemoveNoopOperators) :+ + Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) + + override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+ + ExtractPythonUDFFromJoinCondition.ruleName :+ + ExtractPythonUDFFromAggregate.ruleName :+ ExtractGroupingPythonUDFFromAggregate.ruleName :+ + ExtractPythonUDFs.ruleName :+ + V2ScanRelationPushDown.ruleName :+ + V2Writes.ruleName + + /** + * Optimization batches that are executed before the regular optimization batches (also before + * the finish analysis batch). + */ + def preOptimizationBatches: Seq[Batch] = Nil + + /** + * Optimization batches that are executed after the regular optimization batches, but before the + * batch executing the [[ExperimentalMethods]] optimizer rules. This hook can be used to add + * custom optimizer batches to the Spark optimizer. + * + * Note that 'Extract Python UDFs' batch is an exception and ran after the batches defined here. + */ + def postHocOptimizationBatches: Seq[Batch] = Nil +} diff --git a/shims/spark321/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/shims/spark321/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala new file mode 100644 index 000000000..894df3688 --- /dev/null +++ b/shims/spark321/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -0,0 +1,4356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util.{Locale, NoSuchElementException, Properties, TimeZone} +import java.util +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference +import java.util.zip.Deflater + +import scala.collection.JavaConverters._ +import scala.collection.immutable +import scala.util.Try +import scala.util.control.NonFatal +import scala.util.matching.Regex + +import org.apache.hadoop.fs.Path + +import org.apache.spark.{SparkConf, SparkContext, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.{IGNORE_MISSING_FILES => SPARK_IGNORE_MISSING_FILES} +import org.apache.spark.network.util.ByteUnit +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.{HintErrorLogger, Resolver} +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator +import org.apache.spark.sql.catalyst.plans.logical.HintErrorHandler +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.types.{AtomicType, TimestampNTZType, TimestampType} +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.util.Utils + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines the configuration options for Spark SQL. +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +object SQLConf { + + private[this] val sqlConfEntriesUpdateLock = new Object + + @volatile + private[this] var sqlConfEntries: util.Map[String, ConfigEntry[_]] = util.Collections.emptyMap() + + private[this] val staticConfKeysUpdateLock = new Object + + @volatile + private[this] var staticConfKeys: java.util.Set[String] = util.Collections.emptySet() + + private def register(entry: ConfigEntry[_]): Unit = sqlConfEntriesUpdateLock.synchronized { + require(!sqlConfEntries.containsKey(entry.key), + s"Duplicate SQLConfigEntry. ${entry.key} has been registered") + val updatedMap = new java.util.HashMap[String, ConfigEntry[_]](sqlConfEntries) + updatedMap.put(entry.key, entry) + sqlConfEntries = updatedMap + } + + // For testing only + private[sql] def unregister(entry: ConfigEntry[_]): Unit = sqlConfEntriesUpdateLock.synchronized { + val updatedMap = new java.util.HashMap[String, ConfigEntry[_]](sqlConfEntries) + updatedMap.remove(entry.key) + sqlConfEntries = updatedMap + } + + private[internal] def getConfigEntry(key: String): ConfigEntry[_] = { + sqlConfEntries.get(key) + } + + private[internal] def getConfigEntries(): util.Collection[ConfigEntry[_]] = { + sqlConfEntries.values() + } + + private[internal] def containsConfigEntry(entry: ConfigEntry[_]): Boolean = { + getConfigEntry(entry.key) == entry + } + + private[sql] def containsConfigKey(key: String): Boolean = { + sqlConfEntries.containsKey(key) + } + + def registerStaticConfigKey(key: String): Unit = staticConfKeysUpdateLock.synchronized { + val updated = new util.HashSet[String](staticConfKeys) + updated.add(key) + staticConfKeys = updated + } + + def isStaticConfigKey(key: String): Boolean = staticConfKeys.contains(key) + + def buildConf(key: String): ConfigBuilder = ConfigBuilder(key).onCreate(register) + + def buildStaticConf(key: String): ConfigBuilder = { + ConfigBuilder(key).onCreate { entry => + SQLConf.registerStaticConfigKey(entry.key) + SQLConf.register(entry) + } + } + + /** + * Merge all non-static configs to the SQLConf. For example, when the 1st [[SparkSession]] and + * the global [[SharedState]] have been initialized, all static configs have taken affect and + * should not be set to other values. Other later created sessions should respect all static + * configs and only be able to change non-static configs. + */ + private[sql] def mergeNonStaticSQLConfigs( + sqlConf: SQLConf, + configs: Map[String, String]): Unit = { + for ((k, v) <- configs if !staticConfKeys.contains(k)) { + sqlConf.setConfString(k, v) + } + } + + /** + * Extract entries from `SparkConf` and put them in the `SQLConf` + */ + private[sql] def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = { + sparkConf.getAll.foreach { case (k, v) => + sqlConf.setConfString(k, v) + } + } + + /** + * Default config. Only used when there is no active SparkSession for the thread. + * See [[get]] for more information. + */ + private lazy val fallbackConf = new ThreadLocal[SQLConf] { + override def initialValue: SQLConf = new SQLConf + } + + /** See [[get]] for more information. */ + def getFallbackConf: SQLConf = fallbackConf.get() + + private lazy val existingConf = new ThreadLocal[SQLConf] { + override def initialValue: SQLConf = null + } + + def withExistingConf[T](conf: SQLConf)(f: => T): T = { + val old = existingConf.get() + existingConf.set(conf) + try { + f + } finally { + if (old != null) { + existingConf.set(old) + } else { + existingConf.remove() + } + } + } + + /** + * Defines a getter that returns the SQLConf within scope. + * See [[get]] for more information. + */ + private val confGetter = new AtomicReference[() => SQLConf](() => fallbackConf.get()) + + /** + * Sets the active config object within the current scope. + * See [[get]] for more information. + */ + def setSQLConfGetter(getter: () => SQLConf): Unit = { + confGetter.set(getter) + } + + /** + * Returns the active config object within the current scope. If there is an active SparkSession, + * the proper SQLConf associated with the thread's active session is used. If it's called from + * tasks in the executor side, a SQLConf will be created from job local properties, which are set + * and propagated from the driver side, unless a `SQLConf` has been set in the scope by + * `withExistingConf` as done for propagating SQLConf for operations performed on RDDs created + * from DataFrames. + * + * The way this works is a little bit convoluted, due to the fact that config was added initially + * only for physical plans (and as a result not in sql/catalyst module). + * + * The first time a SparkSession is instantiated, we set the [[confGetter]] to return the + * active SparkSession's config. If there is no active SparkSession, it returns using the thread + * local [[fallbackConf]]. The reason [[fallbackConf]] is a thread local (rather than just a conf) + * is to support setting different config options for different threads so we can potentially + * run tests in parallel. At the time this feature was implemented, this was a no-op since we + * run unit tests (that does not involve SparkSession) in serial order. + */ + def get: SQLConf = { + if (TaskContext.get != null) { + val conf = existingConf.get() + if (conf != null) { + conf + } else { + new ReadOnlySQLConf(TaskContext.get()) + } + } else { + val isSchedulerEventLoopThread = SparkContext.getActive + .flatMap { sc => Option(sc.dagScheduler) } + .map(_.eventProcessLoop.eventThread) + .exists(_.getId == Thread.currentThread().getId) + if (isSchedulerEventLoopThread) { + // DAGScheduler event loop thread does not have an active SparkSession, the `confGetter` + // will return `fallbackConf` which is unexpected. Here we require the caller to get the + // conf within `withExistingConf`, otherwise fail the query. + val conf = existingConf.get() + if (conf != null) { + conf + } else if (Utils.isTesting) { + throw QueryExecutionErrors.cannotGetSQLConfInSchedulerEventLoopThreadError() + } else { + confGetter.get()() + } + } else { + val conf = existingConf.get() + if (conf != null) { + conf + } else { + confGetter.get()() + } + } + } + } + + val ANALYZER_MAX_ITERATIONS = buildConf("spark.sql.analyzer.maxIterations") + .internal() + .doc("The max number of iterations the analyzer runs.") + .version("3.0.0") + .intConf + .createWithDefault(100) + + val OPTIMIZER_EXCLUDED_RULES = buildConf("spark.sql.optimizer.excludedRules") + .doc("Configures a list of rules to be disabled in the optimizer, in which the rules are " + + "specified by their rule names and separated by comma. It is not guaranteed that all the " + + "rules in this configuration will eventually be excluded, as some rules are necessary " + + "for correctness. The optimizer will log the rules that have indeed been excluded.") + .version("2.4.0") + .stringConf + .createOptional + + val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") + .internal() + .doc("The max number of iterations the optimizer runs.") + .version("2.0.0") + .intConf + .createWithDefault(100) + + val OPTIMIZER_INSET_CONVERSION_THRESHOLD = + buildConf("spark.sql.optimizer.inSetConversionThreshold") + .internal() + .doc("The threshold of set size for InSet conversion.") + .version("2.0.0") + .intConf + .createWithDefault(10) + + val OPTIMIZER_INSET_SWITCH_THRESHOLD = + buildConf("spark.sql.optimizer.inSetSwitchThreshold") + .internal() + .doc("Configures the max set size in InSet for which Spark will generate code with " + + "switch statements. This is applicable only to bytes, shorts, ints, dates.") + .version("3.0.0") + .intConf + .checkValue(threshold => threshold >= 0 && threshold <= 600, "The max set size " + + "for using switch statements in InSet must be non-negative and less than or equal to 600") + .createWithDefault(400) + + val PLAN_CHANGE_LOG_LEVEL = buildConf("spark.sql.planChangeLog.level") + .internal() + .doc("Configures the log level for logging the change from the original plan to the new " + + "plan after a rule or batch is applied. The value can be 'trace', 'debug', 'info', " + + "'warn', or 'error'. The default log level is 'trace'.") + .version("3.1.0") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValue(logLevel => Set("TRACE", "DEBUG", "INFO", "WARN", "ERROR").contains(logLevel), + "Invalid value for 'spark.sql.planChangeLog.level'. Valid values are " + + "'trace', 'debug', 'info', 'warn' and 'error'.") + .createWithDefault("trace") + + val PLAN_CHANGE_LOG_RULES = buildConf("spark.sql.planChangeLog.rules") + .internal() + .doc("Configures a list of rules for logging plan changes, in which the rules are " + + "specified by their rule names and separated by comma.") + .version("3.1.0") + .stringConf + .createOptional + + val PLAN_CHANGE_LOG_BATCHES = buildConf("spark.sql.planChangeLog.batches") + .internal() + .doc("Configures a list of batches for logging plan changes, in which the batches " + + "are specified by their batch names and separated by comma.") + .version("3.1.0") + .stringConf + .createOptional + + val DYNAMIC_PARTITION_PRUNING_ENABLED = + buildConf("spark.sql.optimizer.dynamicPartitionPruning.enabled") + .doc("When true, we will generate predicate for partition column when it's used as join key") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val DYNAMIC_PARTITION_PRUNING_USE_STATS = + buildConf("spark.sql.optimizer.dynamicPartitionPruning.useStats") + .internal() + .doc("When true, distinct count statistics will be used for computing the data size of the " + + "partitioned table after dynamic partition pruning, in order to evaluate if it is worth " + + "adding an extra subquery as the pruning filter if broadcast reuse is not applicable.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO = + buildConf("spark.sql.optimizer.dynamicPartitionPruning.fallbackFilterRatio") + .internal() + .doc("When statistics are not available or configured not to be used, this config will be " + + "used as the fallback filter ratio for computing the data size of the partitioned table " + + "after dynamic partition pruning, in order to evaluate if it is worth adding an extra " + + "subquery as the pruning filter if broadcast reuse is not applicable.") + .version("3.0.0") + .doubleConf + .createWithDefault(0.5) + + val DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY = + buildConf("spark.sql.optimizer.dynamicPartitionPruning.reuseBroadcastOnly") + .internal() + .doc("When true, dynamic partition pruning will only apply when the broadcast exchange of " + + "a broadcast hash join operation can be reused as the dynamic pruning filter.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED = + buildConf("spark.sql.optimizer.runtimeFilter.semiJoinReduction.enabled") + .doc("When true and if one side of a shuffle join has a selective predicate, we attempt " + + "to insert a semi join in the other side to reduce the amount of shuffle data") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + + val RUNTIME_FILTER_NUMBER_THRESHOLD = + buildConf("spark.sql.optimizer.runtimeFilter.number.threshold") + .doc("The total number of injected runtime filters (non-DPP) for a single " + + "query. This is to prevent driver OOMs with too many Bloom filters") + .version("3.3.0") + .intConf + .checkValue(threshold => threshold >= 0, "The threshold should be >= 0") + .createWithDefault(10) + + lazy val RUNTIME_BLOOM_FILTER_ENABLED = + buildConf("spark.sql.optimizer.runtime.bloomFilter.enabled") + .doc("When true and if one side of a shuffle join has a selective predicate, we attempt " + + "to insert a bloom filter in the other side to reduce the amount of shuffle data") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + + val RUNTIME_BLOOM_FILTER_THRESHOLD = + buildConf("spark.sql.optimizer.runtime.bloomFilter.threshold") + .doc("Size threshold of the bloom filter creation side plan. Estimated size needs to be " + + "under this value to try to inject bloom filter") + .version("3.3.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("10MB") + + val RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD = + buildConf("spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizethreshold") + .doc("Byte size threshold of the Bloom filter application side plan's aggregated scan " + + "size. Aggregated scan byte size of the Bloom filter application side needs to be over " + + "this value to inject a bloom filter") + .version("3.3.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("10GB") + + val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed") + .doc("When set to true Spark SQL will automatically select a compression codec for each " + + "column based on statistics of the data.") + .version("1.0.1") + .booleanConf + .createWithDefault(true) + + val COLUMN_BATCH_SIZE = buildConf("spark.sql.inMemoryColumnarStorage.batchSize") + .doc("Controls the size of batches for columnar caching. Larger batch sizes can improve " + + "memory utilization and compression, but risk OOMs when caching data.") + .version("1.1.1") + .intConf + .createWithDefault(10000) + + val IN_MEMORY_PARTITION_PRUNING = + buildConf("spark.sql.inMemoryColumnarStorage.partitionPruning") + .internal() + .doc("When true, enable partition pruning for in-memory columnar tables.") + .version("1.2.0") + .booleanConf + .createWithDefault(true) + + val IN_MEMORY_TABLE_SCAN_STATISTICS_ENABLED = + buildConf("spark.sql.inMemoryTableScanStatistics.enable") + .internal() + .doc("When true, enable in-memory table scan accumulators.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val CACHE_VECTORIZED_READER_ENABLED = + buildConf("spark.sql.inMemoryColumnarStorage.enableVectorizedReader") + .doc("Enables vectorized reader for columnar caching.") + .version("2.3.1") + .booleanConf + .createWithDefault(true) + + val COLUMN_VECTOR_OFFHEAP_ENABLED = + buildConf("spark.sql.columnVector.offheap.enabled") + .internal() + .doc("When true, use OffHeapColumnVector in ColumnarBatch.") + .version("2.3.0") + .booleanConf + .createWithDefault(false) + + val PREFER_SORTMERGEJOIN = buildConf("spark.sql.join.preferSortMergeJoin") + .internal() + .doc("When true, prefer sort merge join over shuffled hash join. " + + "Sort merge join consumes less memory than shuffled hash join and it works efficiently " + + "when both join tables are large. On the other hand, shuffled hash join can improve " + + "performance (e.g., of full outer joins) when one of join tables is much smaller.") + .version("2.0.0") + .booleanConf + .createWithDefault(true) + + val RADIX_SORT_ENABLED = buildConf("spark.sql.sort.enableRadixSort") + .internal() + .doc("When true, enable use of radix sort when possible. Radix sort is much faster but " + + "requires additional memory to be reserved up-front. The memory overhead may be " + + "significant when sorting very small rows (up to 50% more in this case).") + .version("2.0.0") + .booleanConf + .createWithDefault(true) + + val AUTO_BROADCASTJOIN_THRESHOLD = buildConf("spark.sql.autoBroadcastJoinThreshold") + .doc("Configures the maximum size in bytes for a table that will be broadcast to all worker " + + "nodes when performing a join. By setting this value to -1 broadcasting can be disabled. " + + "Note that currently statistics are only supported for Hive Metastore tables where the " + + "command `ANALYZE TABLE COMPUTE STATISTICS noscan` has been " + + "run, and file-based data source tables where the statistics are computed directly on " + + "the files of data.") + .version("1.1.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("10MB") + + val LIMIT_SCALE_UP_FACTOR = buildConf("spark.sql.limit.scaleUpFactor") + .internal() + .doc("Minimal increase rate in number of partitions between attempts when executing a take " + + "on a query. Higher values lead to more partitions read. Lower values might lead to " + + "longer execution times as more jobs will be run") + .version("2.1.1") + .intConf + .createWithDefault(4) + + val ADVANCED_PARTITION_PREDICATE_PUSHDOWN = + buildConf("spark.sql.hive.advancedPartitionPredicatePushdown.enabled") + .internal() + .doc("When true, advanced partition predicate pushdown into Hive metastore is enabled.") + .version("2.3.0") + .booleanConf + .createWithDefault(true) + + val LEAF_NODE_DEFAULT_PARALLELISM = buildConf("spark.sql.leafNodeDefaultParallelism") + .doc("The default parallelism of Spark SQL leaf nodes that produce data, such as the file " + + "scan node, the local data scan node, the range node, etc. The default value of this " + + "config is 'SparkContext#defaultParallelism'.") + .version("3.2.0") + .intConf + .checkValue(_ > 0, "The value of spark.sql.leafNodeDefaultParallelism must be positive.") + .createOptional + + val SHUFFLE_PARTITIONS = buildConf("spark.sql.shuffle.partitions") + .doc("The default number of partitions to use when shuffling data for joins or aggregations. " + + "Note: For structured streaming, this configuration cannot be changed between query " + + "restarts from the same checkpoint location.") + .version("1.1.0") + .intConf + .checkValue(_ > 0, "The value of spark.sql.shuffle.partitions must be positive") + .createWithDefault(200) + + val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE = + buildConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize") + .internal() + .doc("(Deprecated since Spark 3.0)") + .version("1.6.0") + .bytesConf(ByteUnit.BYTE) + .checkValue(_ > 0, "advisoryPartitionSizeInBytes must be positive") + .createWithDefaultString("64MB") + + val ADAPTIVE_EXECUTION_ENABLED = buildConf("spark.sql.adaptive.enabled") + .doc("When true, enable adaptive query execution, which re-optimizes the query plan in the " + + "middle of query execution, based on accurate runtime statistics.") + .version("1.6.0") + .booleanConf + .createWithDefault(true) + + val ADAPTIVE_EXECUTION_FORCE_APPLY = buildConf("spark.sql.adaptive.forceApply") + .internal() + .doc("Adaptive query execution is skipped when the query does not have exchanges or " + + "sub-queries. By setting this config to true (together with " + + s"'${ADAPTIVE_EXECUTION_ENABLED.key}' set to true), Spark will force apply adaptive query " + + "execution for all supported queries.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val ADAPTIVE_EXECUTION_LOG_LEVEL = buildConf("spark.sql.adaptive.logLevel") + .internal() + .doc("Configures the log level for adaptive execution logging of plan changes. The value " + + "can be 'trace', 'debug', 'info', 'warn', or 'error'. The default log level is 'debug'.") + .version("3.0.0") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(Set("TRACE", "DEBUG", "INFO", "WARN", "ERROR")) + .createWithDefault("debug") + + val ADVISORY_PARTITION_SIZE_IN_BYTES = + buildConf("spark.sql.adaptive.advisoryPartitionSizeInBytes") + .doc("The advisory size in bytes of the shuffle partition during adaptive optimization " + + s"(when ${ADAPTIVE_EXECUTION_ENABLED.key} is true). It takes effect when Spark " + + "coalesces small shuffle partitions or splits skewed shuffle partition.") + .version("3.0.0") + .fallbackConf(SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) + + val COALESCE_PARTITIONS_ENABLED = + buildConf("spark.sql.adaptive.coalescePartitions.enabled") + .doc(s"When true and '${ADAPTIVE_EXECUTION_ENABLED.key}' is true, Spark will coalesce " + + "contiguous shuffle partitions according to the target size (specified by " + + s"'${ADVISORY_PARTITION_SIZE_IN_BYTES.key}'), to avoid too many small tasks.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val COALESCE_PARTITIONS_PARALLELISM_FIRST = + buildConf("spark.sql.adaptive.coalescePartitions.parallelismFirst") + .doc("When true, Spark does not respect the target size specified by " + + s"'${ADVISORY_PARTITION_SIZE_IN_BYTES.key}' (default 64MB) when coalescing contiguous " + + "shuffle partitions, but adaptively calculate the target size according to the default " + + "parallelism of the Spark cluster. The calculated size is usually smaller than the " + + "configured target size. This is to maximize the parallelism and avoid performance " + + "regression when enabling adaptive query execution. It's recommended to set this config " + + "to false and respect the configured target size.") + .version("3.2.0") + .booleanConf + .createWithDefault(true) + + val COALESCE_PARTITIONS_MIN_PARTITION_SIZE = + buildConf("spark.sql.adaptive.coalescePartitions.minPartitionSize") + .doc("The minimum size of shuffle partitions after coalescing. This is useful when the " + + "adaptively calculated target size is too small during partition coalescing.") + .version("3.2.0") + .bytesConf(ByteUnit.BYTE) + .checkValue(_ > 0, "minPartitionSize must be positive") + .createWithDefaultString("1MB") + + val COALESCE_PARTITIONS_MIN_PARTITION_NUM = + buildConf("spark.sql.adaptive.coalescePartitions.minPartitionNum") + .internal() + .doc("(deprecated) The suggested (not guaranteed) minimum number of shuffle partitions " + + "after coalescing. If not set, the default value is the default parallelism of the " + + "Spark cluster. This configuration only has an effect when " + + s"'${ADAPTIVE_EXECUTION_ENABLED.key}' and " + + s"'${COALESCE_PARTITIONS_ENABLED.key}' are both true.") + .version("3.0.0") + .intConf + .checkValue(_ > 0, "The minimum number of partitions must be positive.") + .createOptional + + val COALESCE_PARTITIONS_INITIAL_PARTITION_NUM = + buildConf("spark.sql.adaptive.coalescePartitions.initialPartitionNum") + .doc("The initial number of shuffle partitions before coalescing. If not set, it equals to " + + s"${SHUFFLE_PARTITIONS.key}. This configuration only has an effect when " + + s"'${ADAPTIVE_EXECUTION_ENABLED.key}' and '${COALESCE_PARTITIONS_ENABLED.key}' " + + "are both true.") + .version("3.0.0") + .intConf + .checkValue(_ > 0, "The initial number of partitions must be positive.") + .createOptional + + val FETCH_SHUFFLE_BLOCKS_IN_BATCH = + buildConf("spark.sql.adaptive.fetchShuffleBlocksInBatch") + .internal() + .doc("Whether to fetch the contiguous shuffle blocks in batch. Instead of fetching blocks " + + "one by one, fetching contiguous shuffle blocks for the same map task in batch can " + + "reduce IO and improve performance. Note, multiple contiguous blocks exist in single " + + s"fetch request only happen when '${ADAPTIVE_EXECUTION_ENABLED.key}' and " + + s"'${COALESCE_PARTITIONS_ENABLED.key}' are both true. This feature also depends " + + "on a relocatable serializer, the concatenation support codec in use, the new version " + + "shuffle fetch protocol and io encryption is disabled.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val LOCAL_SHUFFLE_READER_ENABLED = + buildConf("spark.sql.adaptive.localShuffleReader.enabled") + .doc(s"When true and '${ADAPTIVE_EXECUTION_ENABLED.key}' is true, Spark tries to use local " + + "shuffle reader to read the shuffle data when the shuffle partitioning is not needed, " + + "for example, after converting sort-merge join to broadcast-hash join.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val SKEW_JOIN_ENABLED = + buildConf("spark.sql.adaptive.skewJoin.enabled") + .doc(s"When true and '${ADAPTIVE_EXECUTION_ENABLED.key}' is true, Spark dynamically " + + "handles skew in shuffled join (sort-merge and shuffled hash) by splitting (and " + + "replicating if needed) skewed partitions.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val SKEW_JOIN_SKEWED_PARTITION_FACTOR = + buildConf("spark.sql.adaptive.skewJoin.skewedPartitionFactor") + .doc("A partition is considered as skewed if its size is larger than this factor " + + "multiplying the median partition size and also larger than " + + "'spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes'") + .version("3.0.0") + .intConf + .checkValue(_ >= 0, "The skew factor cannot be negative.") + .createWithDefault(5) + + val SKEW_JOIN_SKEWED_PARTITION_THRESHOLD = + buildConf("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes") + .doc("A partition is considered as skewed if its size in bytes is larger than this " + + s"threshold and also larger than '${SKEW_JOIN_SKEWED_PARTITION_FACTOR.key}' " + + "multiplying the median partition size. Ideally this config should be set larger " + + s"than '${ADVISORY_PARTITION_SIZE_IN_BYTES.key}'.") + .version("3.0.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("256MB") + + val NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN = + buildConf("spark.sql.adaptive.nonEmptyPartitionRatioForBroadcastJoin") + .internal() + .doc("The relation with a non-empty partition ratio lower than this config will not be " + + "considered as the build side of a broadcast-hash join in adaptive execution regardless " + + "of its size.This configuration only has an effect when " + + s"'${ADAPTIVE_EXECUTION_ENABLED.key}' is true.") + .version("3.0.0") + .doubleConf + .checkValue(_ >= 0, "The non-empty partition ratio must be positive number.") + .createWithDefault(0.2) + + val ADAPTIVE_OPTIMIZER_EXCLUDED_RULES = + buildConf("spark.sql.adaptive.optimizer.excludedRules") + .doc("Configures a list of rules to be disabled in the adaptive optimizer, in which the " + + "rules are specified by their rule names and separated by comma. The optimizer will log " + + "the rules that have indeed been excluded.") + .version("3.1.0") + .stringConf + .createOptional + + val ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD = + buildConf("spark.sql.adaptive.autoBroadcastJoinThreshold") + .doc("Configures the maximum size in bytes for a table that will be broadcast to all " + + "worker nodes when performing a join. By setting this value to -1 broadcasting can be " + + s"disabled. The default value is same with ${AUTO_BROADCASTJOIN_THRESHOLD.key}. " + + "Note that, this config is used only in adaptive framework.") + .version("3.2.0") + .bytesConf(ByteUnit.BYTE) + .createOptional + + val ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD = + buildConf("spark.sql.adaptive.maxShuffledHashJoinLocalMapThreshold") + .doc("Configures the maximum size in bytes per partition that can be allowed to build " + + "local hash map. If this value is not smaller than " + + s"${ADVISORY_PARTITION_SIZE_IN_BYTES.key} and all the partition size are not larger " + + "than this config, join selection prefer to use shuffled hash join instead of " + + s"sort merge join regardless of the value of ${PREFER_SORTMERGEJOIN.key}.") + .version("3.2.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(0L) + + val ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED = + buildConf("spark.sql.adaptive.optimizeSkewsInRebalancePartitions.enabled") + .doc(s"When true and '${ADAPTIVE_EXECUTION_ENABLED.key}' is true, Spark will optimize the " + + "skewed shuffle partitions in RebalancePartitions and split them to smaller ones " + + s"according to the target size (specified by '${ADVISORY_PARTITION_SIZE_IN_BYTES.key}'), " + + "to avoid data skew.") + .version("3.2.0") + .booleanConf + .createWithDefault(true) + + val ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS = + buildConf("spark.sql.adaptive.customCostEvaluatorClass") + .doc("The custom cost evaluator class to be used for adaptive execution. If not being set," + + " Spark will use its own SimpleCostEvaluator by default.") + .version("3.2.0") + .stringConf + .createOptional + + val SUBEXPRESSION_ELIMINATION_ENABLED = + buildConf("spark.sql.subexpressionElimination.enabled") + .internal() + .doc("When true, common subexpressions will be eliminated.") + .version("1.6.0") + .booleanConf + .createWithDefault(true) + + val SUBEXPRESSION_ELIMINATION_CACHE_MAX_ENTRIES = + buildConf("spark.sql.subexpressionElimination.cache.maxEntries") + .internal() + .doc("The maximum entries of the cache used for interpreted subexpression elimination.") + .version("3.1.0") + .intConf + .checkValue(_ >= 0, "The maximum must not be negative") + .createWithDefault(100) + + val CASE_SENSITIVE = buildConf("spark.sql.caseSensitive") + .internal() + .doc("Whether the query analyzer should be case sensitive or not. " + + "Default to case insensitive. It is highly discouraged to turn on case sensitive mode.") + .version("1.4.0") + .booleanConf + .createWithDefault(false) + + val CONSTRAINT_PROPAGATION_ENABLED = buildConf("spark.sql.constraintPropagation.enabled") + .internal() + .doc("When true, the query optimizer will infer and propagate data constraints in the query " + + "plan to optimize them. Constraint propagation can sometimes be computationally expensive " + + "for certain kinds of query plans (such as those with a large number of predicates and " + + "aliases) which might negatively impact overall runtime.") + .version("2.2.0") + .booleanConf + .createWithDefault(true) + + val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals") + .internal() + .doc("When true, string literals (including regex patterns) remain escaped in our SQL " + + "parser. The default is false since Spark 2.0. Setting it to true can restore the behavior " + + "prior to Spark 2.0.") + .version("2.2.1") + .booleanConf + .createWithDefault(false) + + val FILE_COMPRESSION_FACTOR = buildConf("spark.sql.sources.fileCompressionFactor") + .internal() + .doc("When estimating the output data size of a table scan, multiply the file size with this " + + "factor as the estimated data size, in case the data is compressed in the file and lead to" + + " a heavily underestimated result.") + .version("2.3.1") + .doubleConf + .checkValue(_ > 0, "the value of fileCompressionFactor must be greater than 0") + .createWithDefault(1.0) + + val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema") + .doc("When true, the Parquet data source merges schemas collected from all data files, " + + "otherwise the schema is picked from the summary file or a random data file " + + "if no summary file is available.") + .version("1.5.0") + .booleanConf + .createWithDefault(false) + + val PARQUET_SCHEMA_RESPECT_SUMMARIES = buildConf("spark.sql.parquet.respectSummaryFiles") + .doc("When true, we make assumption that all part-files of Parquet are consistent with " + + "summary files and we will ignore them when merging schema. Otherwise, if this is " + + "false, which is the default, we will merge all part-files. This should be considered " + + "as expert-only option, and shouldn't be enabled before knowing what it means exactly.") + .version("1.5.0") + .booleanConf + .createWithDefault(false) + + val PARQUET_BINARY_AS_STRING = buildConf("spark.sql.parquet.binaryAsString") + .doc("Some other Parquet-producing systems, in particular Impala and older versions of " + + "Spark SQL, do not differentiate between binary data and strings when writing out the " + + "Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide " + + "compatibility with these systems.") + .version("1.1.1") + .booleanConf + .createWithDefault(false) + + val PARQUET_INT96_AS_TIMESTAMP = buildConf("spark.sql.parquet.int96AsTimestamp") + .doc("Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. " + + "Spark would also store Timestamp as INT96 because we need to avoid precision lost of the " + + "nanoseconds field. This flag tells Spark SQL to interpret INT96 data as a timestamp to " + + "provide compatibility with these systems.") + .version("1.3.0") + .booleanConf + .createWithDefault(true) + + val PARQUET_INT96_TIMESTAMP_CONVERSION = buildConf("spark.sql.parquet.int96TimestampConversion") + .doc("This controls whether timestamp adjustments should be applied to INT96 data when " + + "converting to timestamps, for data written by Impala. This is necessary because Impala " + + "stores INT96 data with a different timezone offset than Hive & Spark.") + .version("2.3.0") + .booleanConf + .createWithDefault(false) + + object ParquetOutputTimestampType extends Enumeration { + val INT96, TIMESTAMP_MICROS, TIMESTAMP_MILLIS = Value + } + + val PARQUET_OUTPUT_TIMESTAMP_TYPE = buildConf("spark.sql.parquet.outputTimestampType") + .doc("Sets which Parquet timestamp type to use when Spark writes data to Parquet files. " + + "INT96 is a non-standard but commonly used timestamp type in Parquet. TIMESTAMP_MICROS " + + "is a standard timestamp type in Parquet, which stores number of microseconds from the " + + "Unix epoch. TIMESTAMP_MILLIS is also standard, but with millisecond precision, which " + + "means Spark has to truncate the microsecond portion of its timestamp value.") + .version("2.3.0") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(ParquetOutputTimestampType.values.map(_.toString)) + .createWithDefault(ParquetOutputTimestampType.INT96.toString) + + val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") + .doc("Sets the compression codec used when writing Parquet files. If either `compression` or " + + "`parquet.compression` is specified in the table-specific options/properties, the " + + "precedence would be `compression`, `parquet.compression`, " + + "`spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, " + + "snappy, gzip, lzo, brotli, lz4, zstd.") + .version("1.1.1") + .stringConf + .transform(_.toLowerCase(Locale.ROOT)) + .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo", "lz4", "brotli", "zstd")) + .createWithDefault("snappy") + + val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") + .doc("Enables Parquet filter push-down optimization when set to true.") + .version("1.2.0") + .booleanConf + .createWithDefault(true) + + val PARQUET_FILTER_PUSHDOWN_DATE_ENABLED = buildConf("spark.sql.parquet.filterPushdown.date") + .doc("If true, enables Parquet filter push-down optimization for Date. " + + s"This configuration only has an effect when '${PARQUET_FILTER_PUSHDOWN_ENABLED.key}' is " + + "enabled.") + .version("2.4.0") + .internal() + .booleanConf + .createWithDefault(true) + + val PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.timestamp") + .doc("If true, enables Parquet filter push-down optimization for Timestamp. " + + s"This configuration only has an effect when '${PARQUET_FILTER_PUSHDOWN_ENABLED.key}' is " + + "enabled and Timestamp stored as TIMESTAMP_MICROS or TIMESTAMP_MILLIS type.") + .version("2.4.0") + .internal() + .booleanConf + .createWithDefault(true) + + val PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.decimal") + .doc("If true, enables Parquet filter push-down optimization for Decimal. " + + s"This configuration only has an effect when '${PARQUET_FILTER_PUSHDOWN_ENABLED.key}' is " + + "enabled.") + .version("2.4.0") + .internal() + .booleanConf + .createWithDefault(true) + + val PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.string.startsWith") + .doc("If true, enables Parquet filter push-down optimization for string startsWith function. " + + s"This configuration only has an effect when '${PARQUET_FILTER_PUSHDOWN_ENABLED.key}' is " + + "enabled.") + .version("2.4.0") + .internal() + .booleanConf + .createWithDefault(true) + + val PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD = + buildConf("spark.sql.parquet.pushdown.inFilterThreshold") + .doc("For IN predicate, Parquet filter will push-down a set of OR clauses if its " + + "number of values not exceeds this threshold. Otherwise, Parquet filter will push-down " + + "a value greater than or equal to its minimum value and less than or equal to " + + "its maximum value. By setting this value to 0 this feature can be disabled. " + + s"This configuration only has an effect when '${PARQUET_FILTER_PUSHDOWN_ENABLED.key}' is " + + "enabled.") + .version("2.4.0") + .internal() + .intConf + .checkValue(threshold => threshold >= 0, "The threshold must not be negative.") + .createWithDefault(10) + + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") + .doc("If true, data will be written in a way of Spark 1.4 and earlier. For example, decimal " + + "values will be written in Apache Parquet's fixed-length byte array format, which other " + + "systems such as Apache Hive and Apache Impala use. If false, the newer format in Parquet " + + "will be used. For example, decimals will be written in int-based format. If Parquet " + + "output is intended for use with systems that do not support this newer format, set to true.") + .version("1.6.0") + .booleanConf + .createWithDefault(false) + + val PARQUET_OUTPUT_COMMITTER_CLASS = buildConf("spark.sql.parquet.output.committer.class") + .doc("The output committer class used by Parquet. The specified class needs to be a " + + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + + "of org.apache.parquet.hadoop.ParquetOutputCommitter. If it is not, then metadata " + + "summaries will never be created, irrespective of the value of " + + "parquet.summary.metadata.level") + .version("1.5.0") + .internal() + .stringConf + .createWithDefault("org.apache.parquet.hadoop.ParquetOutputCommitter") + + val PARQUET_VECTORIZED_READER_ENABLED = + buildConf("spark.sql.parquet.enableVectorizedReader") + .doc("Enables vectorized parquet decoding.") + .version("2.0.0") + .booleanConf + .createWithDefault(true) + + val PARQUET_RECORD_FILTER_ENABLED = buildConf("spark.sql.parquet.recordLevelFilter.enabled") + .doc("If true, enables Parquet's native record-level filtering using the pushed down " + + "filters. " + + s"This configuration only has an effect when '${PARQUET_FILTER_PUSHDOWN_ENABLED.key}' " + + "is enabled and the vectorized reader is not used. You can ensure the vectorized reader " + + s"is not used by setting '${PARQUET_VECTORIZED_READER_ENABLED.key}' to false.") + .version("2.3.0") + .booleanConf + .createWithDefault(false) + + val PARQUET_VECTORIZED_READER_BATCH_SIZE = buildConf("spark.sql.parquet.columnarReaderBatchSize") + .doc("The number of rows to include in a parquet vectorized reader batch. The number should " + + "be carefully chosen to minimize overhead and avoid OOMs in reading data.") + .version("2.4.0") + .intConf + .createWithDefault(4096) + + val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") + .doc("Sets the compression codec used when writing ORC files. If either `compression` or " + + "`orc.compress` is specified in the table-specific options/properties, the precedence " + + "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." + + "Acceptable values include: none, uncompressed, snappy, zlib, lzo, zstd, lz4.") + .version("2.3.0") + .stringConf + .transform(_.toLowerCase(Locale.ROOT)) + .checkValues(Set("none", "uncompressed", "snappy", "zlib", "lzo", "zstd", "lz4")) + .createWithDefault("snappy") + + val ORC_IMPLEMENTATION = buildConf("spark.sql.orc.impl") + .doc("When native, use the native version of ORC support instead of the ORC library in Hive. " + + "It is 'hive' by default prior to Spark 2.4.") + .version("2.3.0") + .internal() + .stringConf + .checkValues(Set("hive", "native")) + .createWithDefault("native") + + val ORC_VECTORIZED_READER_ENABLED = buildConf("spark.sql.orc.enableVectorizedReader") + .doc("Enables vectorized orc decoding.") + .version("2.3.0") + .booleanConf + .createWithDefault(true) + + val ORC_VECTORIZED_READER_BATCH_SIZE = buildConf("spark.sql.orc.columnarReaderBatchSize") + .doc("The number of rows to include in a orc vectorized reader batch. The number should " + + "be carefully chosen to minimize overhead and avoid OOMs in reading data.") + .version("2.4.0") + .intConf + .createWithDefault(4096) + + val ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED = + buildConf("spark.sql.orc.enableNestedColumnVectorizedReader") + .doc("Enables vectorized orc decoding for nested column.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + + val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") + .doc("When true, enable filter pushdown for ORC files.") + .version("1.4.0") + .booleanConf + .createWithDefault(true) + + val ORC_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.orc.mergeSchema") + .doc("When true, the Orc data source merges schemas collected from all data files, " + + "otherwise the schema is picked from a random data file.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") + .doc("When true, check all the partition paths under the table\'s root directory " + + "when reading data stored in HDFS. This configuration will be deprecated in the future " + + s"releases and replaced by ${SPARK_IGNORE_MISSING_FILES.key}.") + .version("1.4.0") + .booleanConf + .createWithDefault(false) + + val HIVE_METASTORE_PARTITION_PRUNING = + buildConf("spark.sql.hive.metastorePartitionPruning") + .doc("When true, some predicates will be pushed down into the Hive metastore so that " + + "unmatching partitions can be eliminated earlier.") + .version("1.5.0") + .booleanConf + .createWithDefault(true) + + val HIVE_METASTORE_PARTITION_PRUNING_INSET_THRESHOLD = + buildConf("spark.sql.hive.metastorePartitionPruningInSetThreshold") + .doc("The threshold of set size for InSet predicate when pruning partitions through Hive " + + "Metastore. When the set size exceeds the threshold, we rewrite the InSet predicate " + + "to be greater than or equal to the minimum value in set and less than or equal to the " + + "maximum value in set. Larger values may cause Hive Metastore stack overflow. But for " + + "InSet inside Not with values exceeding the threshold, we won't push it to Hive Metastore." + ) + .version("3.1.0") + .internal() + .intConf + .checkValue(_ > 0, "The value of metastorePartitionPruningInSetThreshold must be positive") + .createWithDefault(1000) + + val HIVE_MANAGE_FILESOURCE_PARTITIONS = + buildConf("spark.sql.hive.manageFilesourcePartitions") + .doc("When true, enable metastore partition management for file source tables as well. " + + "This includes both datasource and converted Hive tables. When partition management " + + "is enabled, datasource tables store partition in the Hive metastore, and use the " + + s"metastore to prune partitions during query planning when " + + s"${HIVE_METASTORE_PARTITION_PRUNING.key} is set to true.") + .version("2.1.1") + .booleanConf + .createWithDefault(true) + + val HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE = + buildConf("spark.sql.hive.filesourcePartitionFileCacheSize") + .doc("When nonzero, enable caching of partition file metadata in memory. All tables share " + + "a cache that can use up to specified num bytes for file metadata. This conf only " + + "has an effect when hive filesource partition management is enabled.") + .version("2.1.1") + .longConf + .createWithDefault(250 * 1024 * 1024) + + object HiveCaseSensitiveInferenceMode extends Enumeration { + val INFER_AND_SAVE, INFER_ONLY, NEVER_INFER = Value + } + + val HIVE_CASE_SENSITIVE_INFERENCE = buildConf("spark.sql.hive.caseSensitiveInferenceMode") + .internal() + .doc("Sets the action to take when a case-sensitive schema cannot be read from a Hive Serde " + + "table's properties when reading the table with Spark native data sources. Valid options " + + "include INFER_AND_SAVE (infer the case-sensitive schema from the underlying data files " + + "and write it back to the table properties), INFER_ONLY (infer the schema but don't " + + "attempt to write it to the table properties) and NEVER_INFER (the default mode-- fallback " + + "to using the case-insensitive metastore schema instead of inferring).") + .version("2.1.1") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString)) + .createWithDefault(HiveCaseSensitiveInferenceMode.NEVER_INFER.toString) + + val HIVE_TABLE_PROPERTY_LENGTH_THRESHOLD = + buildConf("spark.sql.hive.tablePropertyLengthThreshold") + .internal() + .doc("The maximum length allowed in a single cell when storing Spark-specific information " + + "in Hive's metastore as table properties. Currently it covers 2 things: the schema's " + + "JSON string, the histogram of column statistics.") + .version("3.2.0") + .intConf + .createOptional + + val OPTIMIZER_METADATA_ONLY = buildConf("spark.sql.optimizer.metadataOnly") + .internal() + .doc("When true, enable the metadata-only query optimization that use the table's metadata " + + "to produce the partition columns instead of table scans. It applies when all the columns " + + "scanned are partition columns and the query has an aggregate operator that satisfies " + + "distinct semantics. By default the optimization is disabled, and deprecated as of Spark " + + "3.0 since it may return incorrect results when the files are empty, see also SPARK-26709." + + "It will be removed in the future releases. If you must use, use 'SparkSessionExtensions' " + + "instead to inject it as a custom rule.") + .version("2.1.1") + .booleanConf + .createWithDefault(false) + + val COLUMN_NAME_OF_CORRUPT_RECORD = buildConf("spark.sql.columnNameOfCorruptRecord") + .doc("The name of internal column for storing raw/un-parsed JSON and CSV records that fail " + + "to parse.") + .version("1.2.0") + .stringConf + .createWithDefault("_corrupt_record") + + val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout") + .doc("Timeout in seconds for the broadcast wait time in broadcast joins.") + .version("1.3.0") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString(s"${5 * 60}") + + // This is only used for the thriftserver + val THRIFTSERVER_POOL = buildConf("spark.sql.thriftserver.scheduler.pool") + .doc("Set a Fair Scheduler pool for a JDBC client session.") + .version("1.1.1") + .stringConf + .createOptional + + val THRIFTSERVER_INCREMENTAL_COLLECT = + buildConf("spark.sql.thriftServer.incrementalCollect") + .internal() + .doc("When true, enable incremental collection for execution in Thrift Server.") + .version("2.0.3") + .booleanConf + .createWithDefault(false) + + val THRIFTSERVER_FORCE_CANCEL = + buildConf("spark.sql.thriftServer.interruptOnCancel") + .doc("When true, all running tasks will be interrupted if one cancels a query. " + + "When false, all running tasks will remain until finished.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + + val THRIFTSERVER_QUERY_TIMEOUT = + buildConf("spark.sql.thriftServer.queryTimeout") + .doc("Set a query duration timeout in seconds in Thrift Server. If the timeout is set to " + + "a positive value, a running query will be cancelled automatically when the timeout is " + + "exceeded, otherwise the query continues to run till completion. If timeout values are " + + "set for each statement via `java.sql.Statement.setQueryTimeout` and they are smaller " + + "than this configuration value, they take precedence. If you set this timeout and prefer " + + "to cancel the queries right away without waiting task to finish, consider enabling " + + s"${THRIFTSERVER_FORCE_CANCEL.key} together.") + .version("3.1.0") + .timeConf(TimeUnit.SECONDS) + .createWithDefault(0L) + + val THRIFTSERVER_UI_STATEMENT_LIMIT = + buildConf("spark.sql.thriftserver.ui.retainedStatements") + .doc("The number of SQL statements kept in the JDBC/ODBC web UI history.") + .version("1.4.0") + .intConf + .createWithDefault(200) + + val THRIFTSERVER_UI_SESSION_LIMIT = buildConf("spark.sql.thriftserver.ui.retainedSessions") + .doc("The number of SQL client sessions kept in the JDBC/ODBC web UI history.") + .version("1.4.0") + .intConf + .createWithDefault(200) + + // This is used to set the default data source + val DEFAULT_DATA_SOURCE_NAME = buildConf("spark.sql.sources.default") + .doc("The default data source to use in input/output.") + .version("1.3.0") + .stringConf + .createWithDefault("parquet") + + val CONVERT_CTAS = buildConf("spark.sql.hive.convertCTAS") + .internal() + .doc("When true, a table created by a Hive CTAS statement (no USING clause) " + + "without specifying any storage property will be converted to a data source table, " + + s"using the data source set by ${DEFAULT_DATA_SOURCE_NAME.key}.") + .version("2.0.0") + .booleanConf + .createWithDefault(false) + + val GATHER_FASTSTAT = buildConf("spark.sql.hive.gatherFastStats") + .internal() + .doc("When true, fast stats (number of files and total size of all files) will be gathered" + + " in parallel while repairing table partitions to avoid the sequential listing in Hive" + + " metastore.") + .version("2.0.1") + .booleanConf + .createWithDefault(true) + + val PARTITION_COLUMN_TYPE_INFERENCE = + buildConf("spark.sql.sources.partitionColumnTypeInference.enabled") + .doc("When true, automatically infer the data types for partitioned columns.") + .version("1.5.0") + .booleanConf + .createWithDefault(true) + + val BUCKETING_ENABLED = buildConf("spark.sql.sources.bucketing.enabled") + .doc("When false, we will treat bucketed table as normal table") + .version("2.0.0") + .booleanConf + .createWithDefault(true) + + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") + .doc("The maximum number of buckets allowed.") + .version("2.4.0") + .intConf + .checkValue(_ > 0, "the value of spark.sql.sources.bucketing.maxBuckets must be greater than 0") + .createWithDefault(100000) + + val AUTO_BUCKETED_SCAN_ENABLED = + buildConf("spark.sql.sources.bucketing.autoBucketedScan.enabled") + .doc("When true, decide whether to do bucketed scan on input tables based on query plan " + + "automatically. Do not use bucketed scan if 1. query does not have operators to utilize " + + "bucketing (e.g. join, group-by, etc), or 2. there's an exchange operator between these " + + s"operators and table scan. Note when '${BUCKETING_ENABLED.key}' is set to " + + "false, this configuration does not take any effect.") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + + val CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING = + buildConf("spark.sql.optimizer.canChangeCachedPlanOutputPartitioning") + .internal() + .doc("Whether to forcibly enable some optimization rules that can change the output " + + "partitioning of a cached query when executing it for caching. If it is set to true, " + + "queries may need an extra shuffle to read the cached data. This configuration is " + + "disabled by default. Currently, the optimization rules enabled by this configuration " + + s"are ${ADAPTIVE_EXECUTION_ENABLED.key} and ${AUTO_BUCKETED_SCAN_ENABLED.key}.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + + val CROSS_JOINS_ENABLED = buildConf("spark.sql.crossJoin.enabled") + .internal() + .doc("When false, we will throw an error if a query contains a cartesian product without " + + "explicit CROSS JOIN syntax.") + .version("2.0.0") + .booleanConf + .createWithDefault(true) + + val ORDER_BY_ORDINAL = buildConf("spark.sql.orderByOrdinal") + .doc("When true, the ordinal numbers are treated as the position in the select list. " + + "When false, the ordinal numbers in order/sort by clause are ignored.") + .version("2.0.0") + .booleanConf + .createWithDefault(true) + + val GROUP_BY_ORDINAL = buildConf("spark.sql.groupByOrdinal") + .doc("When true, the ordinal numbers in group by clauses are treated as the position " + + "in the select list. When false, the ordinal numbers are ignored.") + .version("2.0.0") + .booleanConf + .createWithDefault(true) + + val GROUP_BY_ALIASES = buildConf("spark.sql.groupByAliases") + .doc("When true, aliases in a select list can be used in group by clauses. When false, " + + "an analysis exception is thrown in the case.") + .version("2.2.0") + .booleanConf + .createWithDefault(true) + + // The output committer class used by data sources. The specified class needs to be a + // subclass of org.apache.hadoop.mapreduce.OutputCommitter. + val OUTPUT_COMMITTER_CLASS = buildConf("spark.sql.sources.outputCommitterClass") + .version("1.4.0") + .internal() + .stringConf + .createOptional + + val FILE_COMMIT_PROTOCOL_CLASS = + buildConf("spark.sql.sources.commitProtocolClass") + .version("2.1.1") + .internal() + .stringConf + .createWithDefault( + "org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol") + + val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = + buildConf("spark.sql.sources.parallelPartitionDiscovery.threshold") + .doc("The maximum number of paths allowed for listing files at driver side. If the number " + + "of detected paths exceeds this value during partition discovery, it tries to list the " + + "files with another Spark distributed job. This configuration is effective only when " + + "using file-based sources such as Parquet, JSON and ORC.") + .version("1.5.0") + .intConf + .checkValue(parallel => parallel >= 0, "The maximum number of paths allowed for listing " + + "files at driver side must not be negative") + .createWithDefault(32) + + val PARALLEL_PARTITION_DISCOVERY_PARALLELISM = + buildConf("spark.sql.sources.parallelPartitionDiscovery.parallelism") + .doc("The number of parallelism to list a collection of path recursively, Set the " + + "number to prevent file listing from generating too many tasks.") + .version("2.1.1") + .internal() + .intConf + .createWithDefault(10000) + + val IGNORE_DATA_LOCALITY = + buildConf("spark.sql.sources.ignoreDataLocality") + .doc("If true, Spark will not fetch the block locations for each file on " + + "listing files. This speeds up file listing, but the scheduler cannot " + + "schedule tasks to take advantage of data locality. It can be particularly " + + "useful if data is read from a remote cluster so the scheduler could never " + + "take advantage of locality anyway.") + .version("3.0.0") + .internal() + .booleanConf + .createWithDefault(false) + + // Whether to automatically resolve ambiguity in join conditions for self-joins. + // See SPARK-6231. + val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = + buildConf("spark.sql.selfJoinAutoResolveAmbiguity") + .version("1.4.0") + .internal() + .booleanConf + .createWithDefault(true) + + val FAIL_AMBIGUOUS_SELF_JOIN_ENABLED = + buildConf("spark.sql.analyzer.failAmbiguousSelfJoin") + .doc("When true, fail the Dataset query if it contains ambiguous self-join.") + .version("3.0.0") + .internal() + .booleanConf + .createWithDefault(true) + + // Whether to retain group by columns or not in GroupedData.agg. + val DATAFRAME_RETAIN_GROUP_COLUMNS = buildConf("spark.sql.retainGroupColumns") + .version("1.4.0") + .internal() + .booleanConf + .createWithDefault(true) + + val DATAFRAME_PIVOT_MAX_VALUES = buildConf("spark.sql.pivotMaxValues") + .doc("When doing a pivot without specifying values for the pivot column this is the maximum " + + "number of (distinct) values that will be collected without error.") + .version("1.6.0") + .intConf + .createWithDefault(10000) + + val RUN_SQL_ON_FILES = buildConf("spark.sql.runSQLOnFiles") + .internal() + .doc("When true, we could use `datasource`.`path` as table in SQL query.") + .version("1.6.0") + .booleanConf + .createWithDefault(true) + + val WHOLESTAGE_CODEGEN_ENABLED = buildConf("spark.sql.codegen.wholeStage") + .internal() + .doc("When true, the whole stage (of multiple operators) will be compiled into single java" + + " method.") + .version("2.0.0") + .booleanConf + .createWithDefault(true) + + val WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME = + buildConf("spark.sql.codegen.useIdInClassName") + .internal() + .doc("When true, embed the (whole-stage) codegen stage ID into " + + "the class name of the generated class as a suffix") + .version("2.3.1") + .booleanConf + .createWithDefault(true) + + val WHOLESTAGE_MAX_NUM_FIELDS = buildConf("spark.sql.codegen.maxFields") + .internal() + .doc("The maximum number of fields (including nested fields) that will be supported before" + + " deactivating whole-stage codegen.") + .version("2.0.0") + .intConf + .createWithDefault(100) + + val CODEGEN_FACTORY_MODE = buildConf("spark.sql.codegen.factoryMode") + .doc("This config determines the fallback behavior of several codegen generators " + + "during tests. `FALLBACK` means trying codegen first and then falling back to " + + "interpreted if any compile error happens. Disabling fallback if `CODEGEN_ONLY`. " + + "`NO_CODEGEN` skips codegen and goes interpreted path always. Note that " + + "this config works only for tests.") + .version("2.4.0") + .internal() + .stringConf + .checkValues(CodegenObjectFactoryMode.values.map(_.toString)) + .createWithDefault(CodegenObjectFactoryMode.FALLBACK.toString) + + val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback") + .internal() + .doc("When true, (whole stage) codegen could be temporary disabled for the part of query that" + + " fail to compile generated code") + .version("2.0.0") + .booleanConf + .createWithDefault(true) + + val CODEGEN_LOGGING_MAX_LINES = buildConf("spark.sql.codegen.logging.maxLines") + .internal() + .doc("The maximum number of codegen lines to log when errors occur. Use -1 for unlimited.") + .version("2.3.0") + .intConf + .checkValue(maxLines => maxLines >= -1, "The maximum must be a positive integer, 0 to " + + "disable logging or -1 to apply no limit.") + .createWithDefault(1000) + + val WHOLESTAGE_HUGE_METHOD_LIMIT = buildConf("spark.sql.codegen.hugeMethodLimit") + .internal() + .doc("The maximum bytecode size of a single compiled Java function generated by whole-stage " + + "codegen. When the compiled function exceeds this threshold, the whole-stage codegen is " + + "deactivated for this subtree of the current query plan. The default value is 65535, which " + + "is the largest bytecode size possible for a valid Java method. When running on HotSpot, " + + s"it may be preferable to set the value to ${CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT} " + + "to match HotSpot's implementation.") + .version("2.3.0") + .intConf + .createWithDefault(65535) + + val CODEGEN_METHOD_SPLIT_THRESHOLD = buildConf("spark.sql.codegen.methodSplitThreshold") + .internal() + .doc("The threshold of source-code splitting in the codegen. When the number of characters " + + "in a single Java function (without comment) exceeds the threshold, the function will be " + + "automatically split to multiple smaller ones. We cannot know how many bytecode will be " + + "generated, so use the code length as metric. When running on HotSpot, a function's " + + "bytecode should not go beyond 8KB, otherwise it will not be JITted; it also should not " + + "be too small, otherwise there will be many function calls.") + .version("3.0.0") + .intConf + .checkValue(threshold => threshold > 0, "The threshold must be a positive integer.") + .createWithDefault(1024) + + val WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR = + buildConf("spark.sql.codegen.splitConsumeFuncByOperator") + .internal() + .doc("When true, whole stage codegen would put the logic of consuming rows of each " + + "physical operator into individual methods, instead of a single big method. This can be " + + "used to avoid oversized function that can miss the opportunity of JIT optimization.") + .version("2.3.1") + .booleanConf + .createWithDefault(true) + + val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes") + .doc("The maximum number of bytes to pack into a single partition when reading files. " + + "This configuration is effective only when using file-based sources such as Parquet, JSON " + + "and ORC.") + .version("2.0.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("128MB") // parquet.block.size + + val FILES_OPEN_COST_IN_BYTES = buildConf("spark.sql.files.openCostInBytes") + .internal() + .doc("The estimated cost to open a file, measured by the number of bytes could be scanned in" + + " the same time. This is used when putting multiple files into a partition. It's better to" + + " over estimated, then the partitions with small files will be faster than partitions with" + + " bigger files (which is scheduled first). This configuration is effective only when using" + + " file-based sources such as Parquet, JSON and ORC.") + .version("2.0.0") + .longConf + .createWithDefault(4 * 1024 * 1024) + + val FILES_MIN_PARTITION_NUM = buildConf("spark.sql.files.minPartitionNum") + .doc("The suggested (not guaranteed) minimum number of split file partitions. " + + "If not set, the default value is `spark.default.parallelism`. This configuration is " + + "effective only when using file-based sources such as Parquet, JSON and ORC.") + .version("3.1.0") + .intConf + .checkValue(v => v > 0, "The min partition number must be a positive integer.") + .createOptional + + val IGNORE_CORRUPT_FILES = buildConf("spark.sql.files.ignoreCorruptFiles") + .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " + + "encountering corrupted files and the contents that have been read will still be returned. " + + "This configuration is effective only when using file-based sources such as Parquet, JSON " + + "and ORC.") + .version("2.1.1") + .booleanConf + .createWithDefault(false) + + val IGNORE_MISSING_FILES = buildConf("spark.sql.files.ignoreMissingFiles") + .doc("Whether to ignore missing files. If true, the Spark jobs will continue to run when " + + "encountering missing files and the contents that have been read will still be returned. " + + "This configuration is effective only when using file-based sources such as Parquet, JSON " + + "and ORC.") + .version("2.3.0") + .booleanConf + .createWithDefault(false) + + val MAX_RECORDS_PER_FILE = buildConf("spark.sql.files.maxRecordsPerFile") + .doc("Maximum number of records to write out to a single file. " + + "If this value is zero or negative, there is no limit.") + .version("2.2.0") + .longConf + .createWithDefault(0) + + val EXCHANGE_REUSE_ENABLED = buildConf("spark.sql.exchange.reuse") + .internal() + .doc("When true, the planner will try to find out duplicated exchanges and re-use them.") + .version("2.0.0") + .booleanConf + .createWithDefault(true) + + val SUBQUERY_REUSE_ENABLED = buildConf("spark.sql.execution.reuseSubquery") + .internal() + .doc("When true, the planner will try to find out duplicated subqueries and re-use them.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val REMOVE_REDUNDANT_PROJECTS_ENABLED = buildConf("spark.sql.execution.removeRedundantProjects") + .internal() + .doc("Whether to remove redundant project exec node based on children's output and " + + "ordering requirement.") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + + val REMOVE_REDUNDANT_SORTS_ENABLED = buildConf("spark.sql.execution.removeRedundantSorts") + .internal() + .doc("Whether to remove redundant physical sort node") + .version("2.4.8") + .booleanConf + .createWithDefault(true) + + val STATE_STORE_PROVIDER_CLASS = + buildConf("spark.sql.streaming.stateStore.providerClass") + .internal() + .doc( + "The class used to manage state data in stateful streaming queries. This class must " + + "be a subclass of StateStoreProvider, and must have a zero-arg constructor. " + + "Note: For structured streaming, this configuration cannot be changed between query " + + "restarts from the same checkpoint location.") + .version("2.3.0") + .stringConf + .createWithDefault( + "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider") + + val STATE_SCHEMA_CHECK_ENABLED = + buildConf("spark.sql.streaming.stateStore.stateSchemaCheck") + .doc("When true, Spark will validate the state schema against schema on existing state and " + + "fail query if it's incompatible.") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + + val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = + buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot") + .internal() + .doc("Minimum number of state store delta files that needs to be generated before they " + + "consolidated into snapshots.") + .version("2.0.0") + .intConf + .createWithDefault(10) + + val STATE_STORE_FORMAT_VALIDATION_ENABLED = + buildConf("spark.sql.streaming.stateStore.formatValidation.enabled") + .internal() + .doc("When true, check if the data from state store is valid or not when running streaming " + + "queries. This can happen if the state store format has been changed. Note, the feature " + + "is only effective in the build-in HDFS state store provider now.") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + + val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion") + .internal() + .doc("State format version used by flatMapGroupsWithState operation in a streaming query") + .version("2.4.0") + .intConf + .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") + .createWithDefault(2) + + val CHECKPOINT_LOCATION = buildConf("spark.sql.streaming.checkpointLocation") + .doc("The default location for storing checkpoint data for streaming queries.") + .version("2.0.0") + .stringConf + .createOptional + + val FORCE_DELETE_TEMP_CHECKPOINT_LOCATION = + buildConf("spark.sql.streaming.forceDeleteTempCheckpointLocation") + .doc("When true, enable temporary checkpoint locations force delete.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val MIN_BATCHES_TO_RETAIN = buildConf("spark.sql.streaming.minBatchesToRetain") + .internal() + .doc("The minimum number of batches that must be retained and made recoverable.") + .version("2.1.1") + .intConf + .createWithDefault(100) + + val MAX_BATCHES_TO_RETAIN_IN_MEMORY = buildConf("spark.sql.streaming.maxBatchesToRetainInMemory") + .internal() + .doc("The maximum number of batches which will be retained in memory to avoid " + + "loading from files. The value adjusts a trade-off between memory usage vs cache miss: " + + "'2' covers both success and direct failure cases, '1' covers only success case, " + + "and '0' covers extreme case - disable cache to maximize memory size of executors.") + .version("2.4.0") + .intConf + .createWithDefault(2) + + val STREAMING_MAINTENANCE_INTERVAL = + buildConf("spark.sql.streaming.stateStore.maintenanceInterval") + .internal() + .doc("The interval in milliseconds between triggering maintenance tasks in StateStore. " + + "The maintenance task executes background maintenance task in all the loaded store " + + "providers if they are still the active instances according to the coordinator. If not, " + + "inactive instances of store providers will be closed.") + .version("2.0.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(TimeUnit.MINUTES.toMillis(1)) // 1 minute + + val STATE_STORE_COMPRESSION_CODEC = + buildConf("spark.sql.streaming.stateStore.compression.codec") + .internal() + .doc("The codec used to compress delta and snapshot files generated by StateStore. " + + "By default, Spark provides four codecs: lz4, lzf, snappy, and zstd. You can also " + + "use fully qualified class names to specify the codec. Default codec is lz4.") + .version("3.1.0") + .stringConf + .createWithDefault("lz4") + + /** + * Note: this is defined in `RocksDBConf.FORMAT_VERSION`. These two places should be updated + * together. + */ + val STATE_STORE_ROCKSDB_FORMAT_VERSION = + buildConf("spark.sql.streaming.stateStore.rocksdb.formatVersion") + .internal() + .doc("Set the RocksDB format version. This will be stored in the checkpoint when starting " + + "a streaming query. The checkpoint will use this RocksDB format version in the entire " + + "lifetime of the query.") + .version("3.2.0") + .intConf + .checkValue(_ >= 0, "Must not be negative") + // 5 is the default table format version for RocksDB 6.20.3. + .createWithDefault(5) + + val STREAMING_AGGREGATION_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.aggregation.stateFormatVersion") + .internal() + .doc("State format version used by streaming aggregation operations in a streaming query. " + + "State between versions are tend to be incompatible, so state format version shouldn't " + + "be modified after running.") + .version("2.4.0") + .intConf + .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") + .createWithDefault(2) + + val STREAMING_STOP_ACTIVE_RUN_ON_RESTART = + buildConf("spark.sql.streaming.stopActiveRunOnRestart") + .doc("Running multiple runs of the same streaming query concurrently is not supported. " + + "If we find a concurrent active run for a streaming query (in the same or different " + + "SparkSessions on the same cluster) and this flag is true, we will stop the old streaming " + + "query run to start the new one.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val STREAMING_JOIN_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.join.stateFormatVersion") + .internal() + .doc("State format version used by streaming join operations in a streaming query. " + + "State between versions are tend to be incompatible, so state format version shouldn't " + + "be modified after running.") + .version("3.0.0") + .intConf + .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") + .createWithDefault(2) + + val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION = + buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition") + .internal() + .doc("When true, streaming session window sorts and merge sessions in local partition " + + "prior to shuffle. This is to reduce the rows to shuffle, but only beneficial when " + + "there're lots of rows in a batch being assigned to same sessions.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + + val STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.sessionWindow.stateFormatVersion") + .internal() + .doc("State format version used by streaming session window in a streaming query. " + + "State between versions are tend to be incompatible, so state format version shouldn't " + + "be modified after running.") + .version("3.2.0") + .intConf + .checkValue(v => Set(1).contains(v), "Valid version is 1") + .createWithDefault(1) + + val UNSUPPORTED_OPERATION_CHECK_ENABLED = + buildConf("spark.sql.streaming.unsupportedOperationCheck") + .internal() + .doc("When true, the logical plan for streaming query will be checked for unsupported" + + " operations.") + .version("2.0.0") + .booleanConf + .createWithDefault(true) + + val USE_DEPRECATED_KAFKA_OFFSET_FETCHING = + buildConf("spark.sql.streaming.kafka.useDeprecatedOffsetFetching") + .internal() + .doc("When true, the deprecated Consumer based offset fetching used which could cause " + + "infinite wait in Spark queries. Such cases query restart is the only workaround. " + + "For further details please see Offset Fetching chapter of Structured Streaming Kafka " + + "Integration Guide.") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + + val STATEFUL_OPERATOR_CHECK_CORRECTNESS_ENABLED = + buildConf("spark.sql.streaming.statefulOperator.checkCorrectness.enabled") + .internal() + .doc("When true, the stateful operators for streaming query will be checked for possible " + + "correctness issue due to global watermark. The correctness issue comes from queries " + + "containing stateful operation which can emit rows older than the current watermark " + + "plus allowed late record delay, which are \"late rows\" in downstream stateful " + + "operations and these rows can be discarded. Please refer the programming guide doc for " + + "more details. Once the issue is detected, Spark will throw analysis exception. " + + "When this config is disabled, Spark will just print warning message for users. " + + "Prior to Spark 3.1.0, the behavior is disabling this config.") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + + val FILESTREAM_SINK_METADATA_IGNORED = + buildConf("spark.sql.streaming.fileStreamSink.ignoreMetadata") + .internal() + .doc("If this is enabled, when Spark reads from the results of a streaming query written " + + "by `FileStreamSink`, Spark will ignore the metadata log and treat it as normal path to " + + "read, e.g. listing files using HDFS APIs.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + + val VARIABLE_SUBSTITUTE_ENABLED = + buildConf("spark.sql.variable.substitute") + .doc("This enables substitution using syntax like `${var}`, `${system:var}`, " + + "and `${env:var}`.") + .version("2.0.0") + .booleanConf + .createWithDefault(true) + + val ENABLE_TWOLEVEL_AGG_MAP = + buildConf("spark.sql.codegen.aggregate.map.twolevel.enabled") + .internal() + .doc("Enable two-level aggregate hash map. When enabled, records will first be " + + "inserted/looked-up at a 1st-level, small, fast map, and then fallback to a " + + "2nd-level, larger, slower map when 1st level is full or keys cannot be found. " + + "When disabled, records go directly to the 2nd level.") + .version("2.3.0") + .booleanConf + .createWithDefault(true) + + val ENABLE_TWOLEVEL_AGG_MAP_PARTIAL_ONLY = + buildConf("spark.sql.codegen.aggregate.map.twolevel.partialOnly") + .internal() + .doc("Enable two-level aggregate hash map for partial aggregate only, " + + "because final aggregate might get more distinct keys compared to partial aggregate. " + + "Overhead of looking up 1st-level map might dominate when having a lot of distinct keys.") + .version("3.2.1") + .booleanConf + .createWithDefault(true) + + val ENABLE_VECTORIZED_HASH_MAP = + buildConf("spark.sql.codegen.aggregate.map.vectorized.enable") + .internal() + .doc("Enable vectorized aggregate hash map. This is for testing/benchmarking only.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val CODEGEN_SPLIT_AGGREGATE_FUNC = + buildConf("spark.sql.codegen.aggregate.splitAggregateFunc.enabled") + .internal() + .doc("When true, the code generator would split aggregate code into individual methods " + + "instead of a single big method. This can be used to avoid oversized function that " + + "can miss the opportunity of JIT optimization.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val MAX_NESTED_VIEW_DEPTH = + buildConf("spark.sql.view.maxNestedViewDepth") + .internal() + .doc("The maximum depth of a view reference in a nested view. A nested view may reference " + + "other nested views, the dependencies are organized in a directed acyclic graph (DAG). " + + "However the DAG depth may become too large and cause unexpected behavior. This " + + "configuration puts a limit on this: when the depth of a view exceeds this value during " + + "analysis, we terminate the resolution to avoid potential errors.") + .version("2.2.0") + .intConf + .checkValue(depth => depth > 0, "The maximum depth of a view reference in a nested view " + + "must be positive.") + .createWithDefault(100) + + val ALLOW_PARAMETERLESS_COUNT = + buildConf("spark.sql.legacy.allowParameterlessCount") + .internal() + .doc("When true, the SQL function 'count' is allowed to take no parameters.") + .version("3.1.1") + .booleanConf + .createWithDefault(false) + + val ALLOW_NON_EMPTY_LOCATION_IN_CTAS = + buildConf("spark.sql.legacy.allowNonEmptyLocationInCTAS") + .internal() + .doc("When false, CTAS with LOCATION throws an analysis exception if the " + + "location is not empty.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + + val ALLOW_STAR_WITH_SINGLE_TABLE_IDENTIFIER_IN_COUNT = + buildConf("spark.sql.legacy.allowStarWithSingleTableIdentifierInCount") + .internal() + .doc("When true, the SQL function 'count' is allowed to take single 'tblName.*' as parameter") + .version("3.2") + .booleanConf + .createWithDefault(false) + + val USE_CURRENT_SQL_CONFIGS_FOR_VIEW = + buildConf("spark.sql.legacy.useCurrentConfigsForView") + .internal() + .doc("When true, SQL Configs of the current active SparkSession instead of the captured " + + "ones will be applied during the parsing and analysis phases of the view resolution.") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + + val STORE_ANALYZED_PLAN_FOR_VIEW = + buildConf("spark.sql.legacy.storeAnalyzedPlanForView") + .internal() + .doc("When true, analyzed plan instead of SQL text will be stored when creating " + + "temporary view") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + + val ALLOW_AUTO_GENERATED_ALIAS_FOR_VEW = + buildConf("spark.sql.legacy.allowAutoGeneratedAliasForView") + .internal() + .doc("When true, it's allowed to use a input query without explicit alias when creating " + + "a permanent view.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + + val STREAMING_FILE_COMMIT_PROTOCOL_CLASS = + buildConf("spark.sql.streaming.commitProtocolClass") + .version("2.1.0") + .internal() + .stringConf + .createWithDefault("org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol") + + val STREAMING_MULTIPLE_WATERMARK_POLICY = + buildConf("spark.sql.streaming.multipleWatermarkPolicy") + .doc("Policy to calculate the global watermark value when there are multiple watermark " + + "operators in a streaming query. The default value is 'min' which chooses " + + "the minimum watermark reported across multiple operators. Other alternative value is " + + "'max' which chooses the maximum across multiple operators. " + + "Note: This configuration cannot be changed between query restarts from the same " + + "checkpoint location.") + .version("2.4.0") + .stringConf + .transform(_.toLowerCase(Locale.ROOT)) + .checkValue( + str => Set("min", "max").contains(str), + "Invalid value for 'spark.sql.streaming.multipleWatermarkPolicy'. " + + "Valid values are 'min' and 'max'") + .createWithDefault("min") // must be same as MultipleWatermarkPolicy.DEFAULT_POLICY_NAME + + val OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD = + buildConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold") + .internal() + .doc("In the case of ObjectHashAggregateExec, when the size of the in-memory hash map " + + "grows too large, we will fall back to sort-based aggregation. This option sets a row " + + "count threshold for the size of the hash map.") + .version("2.2.0") + .intConf + // We are trying to be conservative and use a relatively small default count threshold here + // since the state object of some TypedImperativeAggregate function can be quite large (e.g. + // percentile_approx). + .createWithDefault(128) + + val USE_OBJECT_HASH_AGG = buildConf("spark.sql.execution.useObjectHashAggregateExec") + .internal() + .doc("Decides if we use ObjectHashAggregateExec") + .version("2.2.0") + .booleanConf + .createWithDefault(true) + + val JSON_GENERATOR_IGNORE_NULL_FIELDS = + buildConf("spark.sql.jsonGenerator.ignoreNullFields") + .doc("Whether to ignore null fields when generating JSON objects in JSON data source and " + + "JSON functions such as to_json. " + + "If false, it generates null for null fields in JSON objects.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val JSON_EXPRESSION_OPTIMIZATION = + buildConf("spark.sql.optimizer.enableJsonExpressionOptimization") + .doc("Whether to optimize JSON expressions in SQL optimizer. It includes pruning " + + "unnecessary columns from from_json, simplifying from_json + to_json, to_json + " + + "named_struct(from_json.col1, from_json.col2, ....).") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + + val CSV_EXPRESSION_OPTIMIZATION = + buildConf("spark.sql.optimizer.enableCsvExpressionOptimization") + .doc("Whether to optimize CSV expressions in SQL optimizer. It includes pruning " + + "unnecessary columns from from_csv.") + .version("3.2.0") + .booleanConf + .createWithDefault(true) + + val FILE_SINK_LOG_DELETION = buildConf("spark.sql.streaming.fileSink.log.deletion") + .internal() + .doc("Whether to delete the expired log files in file stream sink.") + .version("2.0.0") + .booleanConf + .createWithDefault(true) + + val FILE_SINK_LOG_COMPACT_INTERVAL = + buildConf("spark.sql.streaming.fileSink.log.compactInterval") + .internal() + .doc("Number of log files after which all the previous files " + + "are compacted into the next log file.") + .version("2.0.0") + .intConf + .createWithDefault(10) + + val FILE_SINK_LOG_CLEANUP_DELAY = + buildConf("spark.sql.streaming.fileSink.log.cleanupDelay") + .internal() + .doc("How long that a file is guaranteed to be visible for all readers.") + .version("2.0.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(TimeUnit.MINUTES.toMillis(10)) // 10 minutes + + val FILE_SOURCE_LOG_DELETION = buildConf("spark.sql.streaming.fileSource.log.deletion") + .internal() + .doc("Whether to delete the expired log files in file stream source.") + .version("2.0.1") + .booleanConf + .createWithDefault(true) + + val FILE_SOURCE_LOG_COMPACT_INTERVAL = + buildConf("spark.sql.streaming.fileSource.log.compactInterval") + .internal() + .doc("Number of log files after which all the previous files " + + "are compacted into the next log file.") + .version("2.0.1") + .intConf + .createWithDefault(10) + + val FILE_SOURCE_LOG_CLEANUP_DELAY = + buildConf("spark.sql.streaming.fileSource.log.cleanupDelay") + .internal() + .doc("How long in milliseconds a file is guaranteed to be visible for all readers.") + .version("2.0.1") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(TimeUnit.MINUTES.toMillis(10)) // 10 minutes + + val FILE_SOURCE_SCHEMA_FORCE_NULLABLE = + buildConf("spark.sql.streaming.fileSource.schema.forceNullable") + .internal() + .doc("When true, force the schema of streaming file source to be nullable (including all " + + "the fields). Otherwise, the schema might not be compatible with actual data, which " + + "leads to corruptions.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val FILE_SOURCE_CLEANER_NUM_THREADS = + buildConf("spark.sql.streaming.fileSource.cleaner.numThreads") + .doc("Number of threads used in the file source completed file cleaner.") + .version("3.0.0") + .intConf + .createWithDefault(1) + + val STREAMING_SCHEMA_INFERENCE = + buildConf("spark.sql.streaming.schemaInference") + .internal() + .doc("Whether file-based streaming sources will infer its own schema") + .version("2.0.0") + .booleanConf + .createWithDefault(false) + + val STREAMING_POLLING_DELAY = + buildConf("spark.sql.streaming.pollingDelay") + .internal() + .doc("How long to delay polling new data when no data is available") + .version("2.0.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(10L) + + val STREAMING_STOP_TIMEOUT = + buildConf("spark.sql.streaming.stopTimeout") + .doc("How long to wait in milliseconds for the streaming execution thread to stop when " + + "calling the streaming query's stop() method. 0 or negative values wait indefinitely.") + .version("3.0.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("0") + + val STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL = + buildConf("spark.sql.streaming.noDataProgressEventInterval") + .internal() + .doc("How long to wait between two progress events when there is no data") + .version("2.1.1") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(10000L) + + val STREAMING_NO_DATA_MICRO_BATCHES_ENABLED = + buildConf("spark.sql.streaming.noDataMicroBatches.enabled") + .doc( + "Whether streaming micro-batch engine will execute batches without data " + + "for eager state management for stateful streaming queries.") + .version("2.4.1") + .booleanConf + .createWithDefault(true) + + val STREAMING_METRICS_ENABLED = + buildConf("spark.sql.streaming.metricsEnabled") + .doc("Whether Dropwizard/Codahale metrics will be reported for active streaming queries.") + .version("2.0.2") + .booleanConf + .createWithDefault(false) + + val STREAMING_PROGRESS_RETENTION = + buildConf("spark.sql.streaming.numRecentProgressUpdates") + .doc("The number of progress updates to retain for a streaming query") + .version("2.1.1") + .intConf + .createWithDefault(100) + + val STREAMING_CHECKPOINT_FILE_MANAGER_CLASS = + buildConf("spark.sql.streaming.checkpointFileManagerClass") + .internal() + .doc("The class used to write checkpoint files atomically. This class must be a subclass " + + "of the interface CheckpointFileManager.") + .version("2.4.0") + .stringConf + + val STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED = + buildConf("spark.sql.streaming.checkpoint.escapedPathCheck.enabled") + .internal() + .doc("Whether to detect a streaming query may pick up an incorrect checkpoint path due " + + "to SPARK-26824.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION = + buildConf("spark.sql.statistics.parallelFileListingInStatsComputation.enabled") + .internal() + .doc("When true, SQL commands use parallel file listing, " + + "as opposed to single thread listing. " + + "This usually speeds up commands that need to list many directories.") + .version("2.4.1") + .booleanConf + .createWithDefault(true) + + val DEFAULT_SIZE_IN_BYTES = buildConf("spark.sql.defaultSizeInBytes") + .internal() + .doc("The default table size used in query planning. By default, it is set to Long.MaxValue " + + s"which is larger than `${AUTO_BROADCASTJOIN_THRESHOLD.key}` to be more conservative. " + + "That is to say by default the optimizer will not choose to broadcast a table unless it " + + "knows for sure its size is small enough.") + .version("1.1.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(Long.MaxValue) + + val ENABLE_FALL_BACK_TO_HDFS_FOR_STATS = buildConf("spark.sql.statistics.fallBackToHdfs") + .doc("When true, it will fall back to HDFS if the table statistics are not available from " + + "table metadata. This is useful in determining if a table is small enough to use " + + "broadcast joins. This flag is effective only for non-partitioned Hive tables. " + + "For non-partitioned data source tables, it will be automatically recalculated if table " + + "statistics are not available. For partitioned data source and partitioned Hive tables, " + + s"It is '${DEFAULT_SIZE_IN_BYTES.key}' if table statistics are not available.") + .version("2.0.0") + .booleanConf + .createWithDefault(false) + + val NDV_MAX_ERROR = + buildConf("spark.sql.statistics.ndv.maxError") + .internal() + .doc("The maximum relative standard deviation allowed in HyperLogLog++ algorithm " + + "when generating column level statistics.") + .version("2.1.1") + .doubleConf + .createWithDefault(0.05) + + val HISTOGRAM_ENABLED = + buildConf("spark.sql.statistics.histogram.enabled") + .doc("Generates histograms when computing column statistics if enabled. Histograms can " + + "provide better estimation accuracy. Currently, Spark only supports equi-height " + + "histogram. Note that collecting histograms takes extra cost. For example, collecting " + + "column statistics usually takes only one table scan, but generating equi-height " + + "histogram will cause an extra table scan.") + .version("2.3.0") + .booleanConf + .createWithDefault(false) + + val HISTOGRAM_NUM_BINS = + buildConf("spark.sql.statistics.histogram.numBins") + .internal() + .doc("The number of bins when generating histograms.") + .version("2.3.0") + .intConf + .checkValue(num => num > 1, "The number of bins must be greater than 1.") + .createWithDefault(254) + + val PERCENTILE_ACCURACY = + buildConf("spark.sql.statistics.percentile.accuracy") + .internal() + .doc("Accuracy of percentile approximation when generating equi-height histograms. " + + "Larger value means better accuracy. The relative error can be deduced by " + + "1.0 / PERCENTILE_ACCURACY.") + .version("2.3.0") + .intConf + .createWithDefault(10000) + + val AUTO_SIZE_UPDATE_ENABLED = + buildConf("spark.sql.statistics.size.autoUpdate.enabled") + .doc("Enables automatic update for table size once table's data is changed. Note that if " + + "the total number of files of the table is very large, this can be expensive and slow " + + "down data change commands.") + .version("2.3.0") + .booleanConf + .createWithDefault(false) + + val CBO_ENABLED = + buildConf("spark.sql.cbo.enabled") + .doc("Enables CBO for estimation of plan statistics when set true.") + .version("2.2.0") + .booleanConf + .createWithDefault(false) + + val PLAN_STATS_ENABLED = + buildConf("spark.sql.cbo.planStats.enabled") + .doc("When true, the logical plan will fetch row counts and column statistics from catalog.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val JOIN_REORDER_ENABLED = + buildConf("spark.sql.cbo.joinReorder.enabled") + .doc("Enables join reorder in CBO.") + .version("2.2.0") + .booleanConf + .createWithDefault(false) + + val JOIN_REORDER_DP_THRESHOLD = + buildConf("spark.sql.cbo.joinReorder.dp.threshold") + .doc("The maximum number of joined nodes allowed in the dynamic programming algorithm.") + .version("2.2.0") + .intConf + .checkValue(number => number > 0, "The maximum number must be a positive integer.") + .createWithDefault(12) + + val JOIN_REORDER_CARD_WEIGHT = + buildConf("spark.sql.cbo.joinReorder.card.weight") + .internal() + .doc("The weight of the ratio of cardinalities (number of rows) " + + "in the cost comparison function. The ratio of sizes in bytes has weight " + + "1 - this value. The weighted geometric mean of these ratios is used to decide " + + "which of the candidate plans will be chosen by the CBO.") + .version("2.2.0") + .doubleConf + .checkValue(weight => weight >= 0 && weight <= 1, "The weight value must be in [0, 1].") + .createWithDefault(0.7) + + val JOIN_REORDER_DP_STAR_FILTER = + buildConf("spark.sql.cbo.joinReorder.dp.star.filter") + .doc("Applies star-join filter heuristics to cost based join enumeration.") + .version("2.2.0") + .booleanConf + .createWithDefault(false) + + val STARSCHEMA_DETECTION = buildConf("spark.sql.cbo.starSchemaDetection") + .doc("When true, it enables join reordering based on star schema detection. ") + .version("2.2.0") + .booleanConf + .createWithDefault(false) + + val STARSCHEMA_FACT_TABLE_RATIO = buildConf("spark.sql.cbo.starJoinFTRatio") + .internal() + .doc("Specifies the upper limit of the ratio between the largest fact tables" + + " for a star join to be considered. ") + .version("2.2.0") + .doubleConf + .createWithDefault(0.9) + + private def isValidTimezone(zone: String): Boolean = { + Try { DateTimeUtils.getZoneId(zone) }.isSuccess + } + + val SESSION_LOCAL_TIMEZONE = buildConf("spark.sql.session.timeZone") + .doc("The ID of session local timezone in the format of either region-based zone IDs or " + + "zone offsets. Region IDs must have the form 'area/city', such as 'America/Los_Angeles'. " + + "Zone offsets must be in the format '(+|-)HH', '(+|-)HH:mm' or '(+|-)HH:mm:ss', e.g '-08', " + + "'+01:00' or '-13:33:33'. Also 'UTC' and 'Z' are supported as aliases of '+00:00'. Other " + + "short names are not recommended to use because they can be ambiguous.") + .version("2.2.0") + .stringConf + .checkValue(isValidTimezone, s"Cannot resolve the given timezone with" + + " ZoneId.of(_, ZoneId.SHORT_IDS)") + .createWithDefaultFunction(() => TimeZone.getDefault.getID) + + val WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD = + buildConf("spark.sql.windowExec.buffer.in.memory.threshold") + .internal() + .doc("Threshold for number of rows guaranteed to be held in memory by the window operator") + .version("2.2.1") + .intConf + .createWithDefault(4096) + + val WINDOW_EXEC_BUFFER_SPILL_THRESHOLD = + buildConf("spark.sql.windowExec.buffer.spill.threshold") + .internal() + .doc("Threshold for number of rows to be spilled by window operator") + .version("2.2.0") + .intConf + .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) + + val SESSION_WINDOW_BUFFER_IN_MEMORY_THRESHOLD = + buildConf("spark.sql.sessionWindow.buffer.in.memory.threshold") + .internal() + .doc("Threshold for number of windows guaranteed to be held in memory by the " + + "session window operator. Note that the buffer is used only for the query Spark " + + "cannot apply aggregations on determining session window.") + .version("3.2.0") + .intConf + .createWithDefault(4096) + + val SESSION_WINDOW_BUFFER_SPILL_THRESHOLD = + buildConf("spark.sql.sessionWindow.buffer.spill.threshold") + .internal() + .doc("Threshold for number of rows to be spilled by window operator. Note that " + + "the buffer is used only for the query Spark cannot apply aggregations on determining " + + "session window.") + .version("3.2.0") + .intConf + .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) + + val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD = + buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold") + .internal() + .doc("Threshold for number of rows guaranteed to be held in memory by the sort merge " + + "join operator") + .version("2.2.1") + .intConf + .createWithDefault(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) + + val SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD = + buildConf("spark.sql.sortMergeJoinExec.buffer.spill.threshold") + .internal() + .doc("Threshold for number of rows to be spilled by sort merge join operator") + .version("2.2.0") + .intConf + .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) + + val CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD = + buildConf("spark.sql.cartesianProductExec.buffer.in.memory.threshold") + .internal() + .doc("Threshold for number of rows guaranteed to be held in memory by the cartesian " + + "product operator") + .version("2.2.1") + .intConf + .createWithDefault(4096) + + val CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD = + buildConf("spark.sql.cartesianProductExec.buffer.spill.threshold") + .internal() + .doc("Threshold for number of rows to be spilled by cartesian product operator") + .version("2.2.0") + .intConf + .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) + + val SUPPORT_QUOTED_REGEX_COLUMN_NAME = buildConf("spark.sql.parser.quotedRegexColumnNames") + .doc("When true, quoted Identifiers (using backticks) in SELECT statement are interpreted" + + " as regular expressions.") + .version("2.3.0") + .booleanConf + .createWithDefault(false) + + val RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION = + buildConf("spark.sql.execution.rangeExchange.sampleSizePerPartition") + .internal() + .doc("Number of points to sample per partition in order to determine the range boundaries" + + " for range partitioning, typically used in global sorting (without limit).") + .version("2.3.0") + .intConf + .createWithDefault(100) + + val ARROW_EXECUTION_ENABLED = + buildConf("spark.sql.execution.arrow.enabled") + .doc("(Deprecated since Spark 3.0, please set 'spark.sql.execution.arrow.pyspark.enabled'.)") + .version("2.3.0") + .booleanConf + .createWithDefault(false) + + val ARROW_PYSPARK_EXECUTION_ENABLED = + buildConf("spark.sql.execution.arrow.pyspark.enabled") + .doc("When true, make use of Apache Arrow for columnar data transfers in PySpark. " + + "This optimization applies to: " + + "1. pyspark.sql.DataFrame.toPandas " + + "2. pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame " + + "The following data types are unsupported: " + + "ArrayType of TimestampType, and nested StructType.") + .version("3.0.0") + .fallbackConf(ARROW_EXECUTION_ENABLED) + + val ARROW_PYSPARK_SELF_DESTRUCT_ENABLED = + buildConf("spark.sql.execution.arrow.pyspark.selfDestruct.enabled") + .doc("(Experimental) When true, make use of Apache Arrow's self-destruct and split-blocks " + + "options for columnar data transfers in PySpark, when converting from Arrow to Pandas. " + + "This reduces memory usage at the cost of some CPU time. " + + "This optimization applies to: pyspark.sql.DataFrame.toPandas " + + "when 'spark.sql.execution.arrow.pyspark.enabled' is set.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + + val PYSPARK_JVM_STACKTRACE_ENABLED = + buildConf("spark.sql.pyspark.jvmStacktrace.enabled") + .doc("When true, it shows the JVM stacktrace in the user-facing PySpark exception " + + "together with Python stacktrace. By default, it is disabled and hides JVM stacktrace " + + "and shows a Python-friendly exception only.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val ARROW_SPARKR_EXECUTION_ENABLED = + buildConf("spark.sql.execution.arrow.sparkr.enabled") + .doc("When true, make use of Apache Arrow for columnar data transfers in SparkR. " + + "This optimization applies to: " + + "1. createDataFrame when its input is an R DataFrame " + + "2. collect " + + "3. dapply " + + "4. gapply " + + "The following data types are unsupported: " + + "FloatType, BinaryType, ArrayType, StructType and MapType.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val ARROW_FALLBACK_ENABLED = + buildConf("spark.sql.execution.arrow.fallback.enabled") + .doc("(Deprecated since Spark 3.0, please set " + + "'spark.sql.execution.arrow.pyspark.fallback.enabled'.)") + .version("2.4.0") + .booleanConf + .createWithDefault(true) + + val ARROW_PYSPARK_FALLBACK_ENABLED = + buildConf("spark.sql.execution.arrow.pyspark.fallback.enabled") + .doc(s"When true, optimizations enabled by '${ARROW_PYSPARK_EXECUTION_ENABLED.key}' will " + + "fallback automatically to non-optimized implementations if an error occurs.") + .version("3.0.0") + .fallbackConf(ARROW_FALLBACK_ENABLED) + + val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = + buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") + .doc("When using Apache Arrow, limit the maximum number of records that can be written " + + "to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.") + .version("2.3.0") + .intConf + .createWithDefault(10000) + + val PANDAS_UDF_BUFFER_SIZE = + buildConf("spark.sql.execution.pandas.udf.buffer.size") + .doc( + s"Same as `${BUFFER_SIZE.key}` but only applies to Pandas UDF executions. If it is not " + + s"set, the fallback is `${BUFFER_SIZE.key}`. Note that Pandas execution requires more " + + "than 4 bytes. Lowering this value could make small Pandas UDF batch iterated and " + + "pipelined; however, it might degrade performance. See SPARK-27870.") + .version("3.0.0") + .fallbackConf(BUFFER_SIZE) + + val PYSPARK_SIMPLIFIEID_TRACEBACK = + buildConf("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled") + .doc( + "When true, the traceback from Python UDFs is simplified. It hides " + + "the Python worker, (de)serialization, etc from PySpark in tracebacks, and only " + + "shows the exception messages from UDFs. Note that this works only with CPython 3.7+.") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + + val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME = + buildConf("spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName") + .internal() + .doc("When true, columns will be looked up by name if labeled with a string and fallback " + + "to use position if not. When false, a grouped map Pandas UDF will assign columns from " + + "the returned Pandas DataFrame based on position, regardless of column label type. " + + "This configuration will be deprecated in future releases.") + .version("2.4.1") + .booleanConf + .createWithDefault(true) + + val PANDAS_ARROW_SAFE_TYPE_CONVERSION = + buildConf("spark.sql.execution.pandas.convertToArrowArraySafely") + .internal() + .doc("When true, Arrow will perform safe type conversion when converting " + + "Pandas.Series to Arrow array during serialization. Arrow will raise errors " + + "when detecting unsafe type conversion like overflow. When false, disabling Arrow's type " + + "check and do type conversions anyway. This config only works for Arrow 0.11.0+.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") + .internal() + .doc("When true, the apply function of the rule verifies whether the right node of the" + + " except operation is of type Filter or Project followed by Filter. If yes, the rule" + + " further verifies 1) Excluding the filter operations from the right (as well as the" + + " left node, if any) on the top, whether both the nodes evaluates to a same result." + + " 2) The left and right nodes don't contain any SubqueryExpressions. 3) The output" + + " column names of the left node are distinct. If all the conditions are met, the" + + " rule will replace the except operation with a Filter by flipping the filter" + + " condition(s) of the right node.") + .version("2.3.0") + .booleanConf + .createWithDefault(true) + + val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = + buildConf("spark.sql.decimalOperations.allowPrecisionLoss") + .internal() + .doc("When true (default), establishing the result type of an arithmetic operation " + + "happens according to Hive behavior and SQL ANSI 2011 specification, i.e. rounding the " + + "decimal part of the result if an exact representation is not possible. Otherwise, NULL " + + "is returned in those cases, as previously.") + .version("2.3.1") + .booleanConf + .createWithDefault(true) + + val LITERAL_PICK_MINIMUM_PRECISION = + buildConf("spark.sql.legacy.literal.pickMinimumPrecision") + .internal() + .doc("When integral literal is used in decimal operations, pick a minimum precision " + + "required by the literal if this config is true, to make the resulting precision and/or " + + "scale smaller. This can reduce the possibility of precision lose and/or overflow.") + .version("2.3.3") + .booleanConf + .createWithDefault(true) + + val SQL_OPTIONS_REDACTION_PATTERN = buildConf("spark.sql.redaction.options.regex") + .doc("Regex to decide which keys in a Spark SQL command's options map contain sensitive " + + "information. The values of options whose names that match this regex will be redacted " + + "in the explain output. This redaction is applied on top of the global redaction " + + s"configuration defined by ${SECRET_REDACTION_PATTERN.key}.") + .version("2.2.2") + .regexConf + .createWithDefault("(?i)url".r) + + val SQL_STRING_REDACTION_PATTERN = + buildConf("spark.sql.redaction.string.regex") + .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + + "information. When this regex matches a string part, that string part is replaced by a " + + "dummy value. This is currently used to redact the output of SQL explain commands. " + + "When this conf is not set, the value from `spark.redaction.string.regex` is used.") + .version("2.3.0") + .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN) + + val CONCAT_BINARY_AS_STRING = buildConf("spark.sql.function.concatBinaryAsString") + .doc("When this option is set to false and all inputs are binary, `functions.concat` returns " + + "an output as binary. Otherwise, it returns as a string.") + .version("2.3.0") + .booleanConf + .createWithDefault(false) + + val ELT_OUTPUT_AS_STRING = buildConf("spark.sql.function.eltOutputAsString") + .doc("When this option is set to false and all inputs are binary, `elt` returns " + + "an output as binary. Otherwise, it returns as a string.") + .version("2.3.0") + .booleanConf + .createWithDefault(false) + + val VALIDATE_PARTITION_COLUMNS = + buildConf("spark.sql.sources.validatePartitionColumns") + .internal() + .doc("When this option is set to true, partition column values will be validated with " + + "user-specified schema. If the validation fails, a runtime exception is thrown. " + + "When this option is set to false, the partition column value will be converted to null " + + "if it can not be casted to corresponding user-specified schema.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE = + buildConf("spark.sql.streaming.continuous.epochBacklogQueueSize") + .doc("The max number of entries to be stored in queue to wait for late epochs. " + + "If this parameter is exceeded by the size of the queue, stream will stop with an error.") + .version("3.0.0") + .intConf + .createWithDefault(10000) + + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = + buildConf("spark.sql.streaming.continuous.executorQueueSize") + .internal() + .doc("The size (measured in number of rows) of the queue used in continuous execution to" + + " buffer the results of a ContinuousDataReader.") + .version("2.3.0") + .intConf + .createWithDefault(1024) + + val CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS = + buildConf("spark.sql.streaming.continuous.executorPollIntervalMs") + .internal() + .doc("The interval at which continuous execution readers will poll to check whether" + + " the epoch has advanced on the driver.") + .version("2.3.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(100) + + val USE_V1_SOURCE_LIST = buildConf("spark.sql.sources.useV1SourceList") + .internal() + .doc("A comma-separated list of data source short names or fully qualified data source " + + "implementation class names for which Data Source V2 code path is disabled. These data " + + "sources will fallback to Data Source V1 code path.") + .version("3.0.0") + .stringConf + .createWithDefault("avro,csv,json,kafka,orc,parquet,text") + + val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") + .doc("A comma-separated list of fully qualified data source register class names for which" + + " StreamWriteSupport is disabled. Writes to these sources will fall back to the V1 Sinks.") + .version("2.3.1") + .stringConf + .createWithDefault("") + + val DISABLED_V2_STREAMING_MICROBATCH_READERS = + buildConf("spark.sql.streaming.disabledV2MicroBatchReaders") + .internal() + .doc( + "A comma-separated list of fully qualified data source register class names for which " + + "MicroBatchReadSupport is disabled. Reads from these sources will fall back to the " + + "V1 Sources.") + .version("2.4.0") + .stringConf + .createWithDefault("") + + val FASTFAIL_ON_FILEFORMAT_OUTPUT = + buildConf("spark.sql.execution.fastFailOnFileFormatOutput") + .internal() + .doc("Whether to fast fail task execution when writing output to FileFormat datasource. " + + "If this is enabled, in `FileFormatWriter` we will catch `FileAlreadyExistsException` " + + "and fast fail output task without further task retry. Only enabling this if you know " + + "the `FileAlreadyExistsException` of the output task is unrecoverable, i.e., further " + + "task attempts won't be able to success. If the `FileAlreadyExistsException` might be " + + "recoverable, you should keep this as disabled and let Spark to retry output tasks. " + + "This is disabled by default.") + .version("3.0.2") + .booleanConf + .createWithDefault(false) + + object PartitionOverwriteMode extends Enumeration { + val STATIC, DYNAMIC = Value + } + + val PARTITION_OVERWRITE_MODE = + buildConf("spark.sql.sources.partitionOverwriteMode") + .doc("When INSERT OVERWRITE a partitioned data source table, we currently support 2 modes: " + + "static and dynamic. In static mode, Spark deletes all the partitions that match the " + + "partition specification(e.g. PARTITION(a=1,b)) in the INSERT statement, before " + + "overwriting. In dynamic mode, Spark doesn't delete partitions ahead, and only overwrite " + + "those partitions that have data written into it at runtime. By default we use static " + + "mode to keep the same behavior of Spark prior to 2.3. Note that this config doesn't " + + "affect Hive serde tables, as they are always overwritten with dynamic mode. This can " + + "also be set as an output option for a data source using key partitionOverwriteMode " + + "(which takes precedence over this setting), e.g. " + + "dataframe.write.option(\"partitionOverwriteMode\", \"dynamic\").save(path)." + ) + .version("2.3.0") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(PartitionOverwriteMode.values.map(_.toString)) + .createWithDefault(PartitionOverwriteMode.STATIC.toString) + + object StoreAssignmentPolicy extends Enumeration { + val ANSI, LEGACY, STRICT = Value + } + + val STORE_ASSIGNMENT_POLICY = + buildConf("spark.sql.storeAssignmentPolicy") + .doc("When inserting a value into a column with different data type, Spark will perform " + + "type coercion. Currently, we support 3 policies for the type coercion rules: ANSI, " + + "legacy and strict. With ANSI policy, Spark performs the type coercion as per ANSI SQL. " + + "In practice, the behavior is mostly the same as PostgreSQL. " + + "It disallows certain unreasonable type conversions such as converting " + + "`string` to `int` or `double` to `boolean`. " + + "With legacy policy, Spark allows the type coercion as long as it is a valid `Cast`, " + + "which is very loose. e.g. converting `string` to `int` or `double` to `boolean` is " + + "allowed. It is also the only behavior in Spark 2.x and it is compatible with Hive. " + + "With strict policy, Spark doesn't allow any possible precision loss or data truncation " + + "in type coercion, e.g. converting `double` to `int` or `decimal` to `double` is " + + "not allowed." + ) + .version("3.0.0") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(StoreAssignmentPolicy.values.map(_.toString)) + .createWithDefault(StoreAssignmentPolicy.ANSI.toString) + + val ANSI_ENABLED = buildConf("spark.sql.ansi.enabled") + .doc("When true, Spark SQL uses an ANSI compliant dialect instead of being Hive compliant. " + + "For example, Spark will throw an exception at runtime instead of returning null results " + + "when the inputs to a SQL operator/function are invalid." + + "For full details of this dialect, you can find them in the section \"ANSI Compliance\" of " + + "Spark's documentation. Some ANSI dialect features may be not from the ANSI SQL " + + "standard directly, but their behaviors align with ANSI SQL's style") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val SORT_BEFORE_REPARTITION = + buildConf("spark.sql.execution.sortBeforeRepartition") + .internal() + .doc("When perform a repartition following a shuffle, the output row ordering would be " + + "nondeterministic. If some downstream stages fail and some tasks of the repartition " + + "stage retry, these tasks may generate different data, and that can lead to correctness " + + "issues. Turn on this config to insert a local sort before actually doing repartition " + + "to generate consistent repartition results. The performance of repartition() may go " + + "down since we insert extra local sort before it.") + .version("2.1.4") + .booleanConf + .createWithDefault(true) + + val NESTED_SCHEMA_PRUNING_ENABLED = + buildConf("spark.sql.optimizer.nestedSchemaPruning.enabled") + .internal() + .doc("Prune nested fields from a logical relation's output which are unnecessary in " + + "satisfying a query. This optimization allows columnar file format readers to avoid " + + "reading unnecessary nested column data. Currently Parquet and ORC are the " + + "data sources that implement this optimization.") + .version("2.4.1") + .booleanConf + .createWithDefault(true) + + val DISABLE_HINTS = + buildConf("spark.sql.optimizer.disableHints") + .internal() + .doc("When true, the optimizer will disable user-specified hints that are additional " + + "directives for better planning of a query.") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + + val NESTED_PREDICATE_PUSHDOWN_FILE_SOURCE_LIST = + buildConf("spark.sql.optimizer.nestedPredicatePushdown.supportedFileSources") + .internal() + .doc("A comma-separated list of data source short names or fully qualified data source " + + "implementation class names for which Spark tries to push down predicates for nested " + + "columns and/or names containing `dots` to data sources. This configuration is only " + + "effective with file-based data sources in DSv1. Currently, Parquet and ORC implement " + + "both optimizations. The other data sources don't support this feature yet. So the " + + "default value is 'parquet,orc'.") + .version("3.0.0") + .stringConf + .createWithDefault("parquet,orc") + + val SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED = + buildConf("spark.sql.optimizer.serializer.nestedSchemaPruning.enabled") + .internal() + .doc("Prune nested fields from object serialization operator which are unnecessary in " + + "satisfying a query. This optimization allows object serializers to avoid " + + "executing unnecessary nested expressions.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val NESTED_PRUNING_ON_EXPRESSIONS = + buildConf("spark.sql.optimizer.expression.nestedPruning.enabled") + .internal() + .doc("Prune nested fields from expressions in an operator which are unnecessary in " + + "satisfying a query. Note that this optimization doesn't prune nested fields from " + + "physical data source scanning. For pruning nested fields from scanning, please use " + + "`spark.sql.optimizer.nestedSchemaPruning.enabled` config.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val DECORRELATE_INNER_QUERY_ENABLED = + buildConf("spark.sql.optimizer.decorrelateInnerQuery.enabled") + .internal() + .doc("Decorrelate inner query by eliminating correlated references and build domain joins.") + .version("3.2.0") + .booleanConf + .createWithDefault(true) + + val OPTIMIZE_ONE_ROW_RELATION_SUBQUERY = + buildConf("spark.sql.optimizer.optimizeOneRowRelationSubquery") + .internal() + .doc("When true, the optimizer will inline subqueries with OneRowRelation as leaf nodes.") + .version("3.2.0") + .booleanConf + .createWithDefault(true) + + val TOP_K_SORT_FALLBACK_THRESHOLD = + buildConf("spark.sql.execution.topKSortFallbackThreshold") + .doc("In SQL queries with a SORT followed by a LIMIT like " + + "'SELECT x FROM t ORDER BY y LIMIT m', if m is under this threshold, do a top-K sort" + + " in memory, otherwise do a global sort which spills to disk if necessary.") + .version("2.4.0") + .intConf + .createWithDefault(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) + + object Deprecated { + val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" + } + + object Replaced { + val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" + } + + val CSV_PARSER_COLUMN_PRUNING = buildConf("spark.sql.csv.parser.columnPruning.enabled") + .internal() + .doc("If it is set to true, column names of the requested schema are passed to CSV parser. " + + "Other column values can be ignored during parsing even if they are malformed.") + .version("2.4.0") + .booleanConf + .createWithDefault(true) + + val CSV_INPUT_BUFFER_SIZE = buildConf("spark.sql.csv.parser.inputBufferSize") + .internal() + .doc("If it is set, it configures the buffer size of CSV input during parsing. " + + "It is the same as inputBufferSize option in CSV which has a higher priority. " + + "Note that this is a workaround for the parsing library's regression, and this " + + "configuration is internal and supposed to be removed in the near future.") + .version("3.0.3") + .intConf + .createOptional + + val REPL_EAGER_EVAL_ENABLED = buildConf("spark.sql.repl.eagerEval.enabled") + .doc("Enables eager evaluation or not. When true, the top K rows of Dataset will be " + + "displayed if and only if the REPL supports the eager evaluation. Currently, the " + + "eager evaluation is supported in PySpark and SparkR. In PySpark, for the notebooks like " + + "Jupyter, the HTML table (generated by _repr_html_) will be returned. For plain Python " + + "REPL, the returned outputs are formatted like dataframe.show(). In SparkR, the returned " + + "outputs are showed similar to R data.frame would.") + .version("2.4.0") + .booleanConf + .createWithDefault(false) + + val REPL_EAGER_EVAL_MAX_NUM_ROWS = buildConf("spark.sql.repl.eagerEval.maxNumRows") + .doc("The max number of rows that are returned by eager evaluation. This only takes " + + s"effect when ${REPL_EAGER_EVAL_ENABLED.key} is set to true. The valid range of this " + + "config is from 0 to (Int.MaxValue - 1), so the invalid config like negative and " + + "greater than (Int.MaxValue - 1) will be normalized to 0 and (Int.MaxValue - 1).") + .version("2.4.0") + .intConf + .createWithDefault(20) + + val REPL_EAGER_EVAL_TRUNCATE = buildConf("spark.sql.repl.eagerEval.truncate") + .doc("The max number of characters for each cell that is returned by eager evaluation. " + + s"This only takes effect when ${REPL_EAGER_EVAL_ENABLED.key} is set to true.") + .version("2.4.0") + .intConf + .createWithDefault(20) + + val FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT = + buildConf("spark.sql.codegen.aggregate.fastHashMap.capacityBit") + .internal() + .doc("Capacity for the max number of rows to be held in memory " + + "by the fast hash aggregate product operator. The bit is not for actual value, " + + "but the actual numBuckets is determined by loadFactor " + + "(e.g: default bit value 16 , the actual numBuckets is ((1 << 16) / 0.5).") + .version("2.4.0") + .intConf + .checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in [10, 30].") + .createWithDefault(16) + + val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec") + .doc("Compression codec used in writing of AVRO files. Supported codecs: " + + "uncompressed, deflate, snappy, bzip2, xz and zstandard. Default codec is snappy.") + .version("2.4.0") + .stringConf + .checkValues(Set("uncompressed", "deflate", "snappy", "bzip2", "xz", "zstandard")) + .createWithDefault("snappy") + + val AVRO_DEFLATE_LEVEL = buildConf("spark.sql.avro.deflate.level") + .doc("Compression level for the deflate codec used in writing of AVRO files. " + + "Valid value must be in the range of from 1 to 9 inclusive or -1. " + + "The default value is -1 which corresponds to 6 level in the current implementation.") + .version("2.4.0") + .intConf + .checkValues((1 to 9).toSet + Deflater.DEFAULT_COMPRESSION) + .createWithDefault(Deflater.DEFAULT_COMPRESSION) + + val LEGACY_SIZE_OF_NULL = buildConf("spark.sql.legacy.sizeOfNull") + .internal() + .doc(s"If it is set to false, or ${ANSI_ENABLED.key} is true, then size of null returns " + + "null. Otherwise, it returns -1, which was inherited from Hive.") + .version("2.4.0") + .booleanConf + .createWithDefault(true) + + val LEGACY_PARSE_NULL_PARTITION_SPEC_AS_STRING_LITERAL = + buildConf("spark.sql.legacy.parseNullPartitionSpecAsStringLiteral") + .internal() + .doc("If it is set to true, `PARTITION(col=null)` is parsed as a string literal of its " + + "text representation, e.g., string 'null', when the partition column is string type. " + + "Otherwise, it is always parsed as a null literal in the partition spec.") + .version("3.0.2") + .booleanConf + .createWithDefault(false) + + val LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED = + buildConf("spark.sql.legacy.replaceDatabricksSparkAvro.enabled") + .internal() + .doc("If it is set to true, the data source provider com.databricks.spark.avro is mapped " + + "to the built-in but external Avro data source module for backward compatibility.") + .version("2.4.0") + .booleanConf + .createWithDefault(true) + + val LEGACY_SETOPS_PRECEDENCE_ENABLED = + buildConf("spark.sql.legacy.setopsPrecedence.enabled") + .internal() + .doc("When set to true and the order of evaluation is not specified by parentheses, the " + + "set operations are performed from left to right as they appear in the query. When set " + + "to false and order of evaluation is not specified by parentheses, INTERSECT operations " + + "are performed before any UNION, EXCEPT and MINUS operations.") + .version("2.4.0") + .booleanConf + .createWithDefault(false) + + val LEGACY_EXPONENT_LITERAL_AS_DECIMAL_ENABLED = + buildConf("spark.sql.legacy.exponentLiteralAsDecimal.enabled") + .internal() + .doc("When set to true, a literal with an exponent (e.g. 1E-30) would be parsed " + + "as Decimal rather than Double.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED = + buildConf("spark.sql.legacy.allowNegativeScaleOfDecimal") + .internal() + .doc("When set to true, negative scale of Decimal type is allowed. For example, " + + "the type of number 1E10BD under legacy mode is DecimalType(2, -9), but is " + + "Decimal(11, 0) in non legacy mode.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val LEGACY_BUCKETED_TABLE_SCAN_OUTPUT_ORDERING = + buildConf("spark.sql.legacy.bucketedTableScan.outputOrdering") + .internal() + .doc("When true, the bucketed table scan will list files during planning to figure out the " + + "output ordering, which is expensive and may make the planning quite slow.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE = + buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere") + .internal() + .doc("If it is set to true, the parser will treat HAVING without GROUP BY as a normal " + + "WHERE, which does not follow SQL standard.") + .version("2.4.1") + .booleanConf + .createWithDefault(false) + + val LEGACY_ALLOW_EMPTY_STRING_IN_JSON = + buildConf("spark.sql.legacy.json.allowEmptyString.enabled") + .internal() + .doc("When set to true, the parser of JSON data source treats empty strings as null for " + + "some data types such as `IntegerType`.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE = + buildConf("spark.sql.legacy.createEmptyCollectionUsingStringType") + .internal() + .doc("When set to true, Spark returns an empty collection with `StringType` as element " + + "type if the `array`/`map` function is called without any parameters. Otherwise, Spark " + + "returns an empty collection with `NullType` as element type.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val LEGACY_ALLOW_UNTYPED_SCALA_UDF = + buildConf("spark.sql.legacy.allowUntypedScalaUDF") + .internal() + .doc("When set to true, user is allowed to use org.apache.spark.sql.functions." + + "udf(f: AnyRef, dataType: DataType). Otherwise, an exception will be thrown at runtime.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val LEGACY_STATISTICAL_AGGREGATE = + buildConf("spark.sql.legacy.statisticalAggregate") + .internal() + .doc("When set to true, statistical aggregate function returns Double.NaN " + + "if divide by zero occurred during expression evaluation, otherwise, it returns null. " + + "Before version 3.1.0, it returns NaN in divideByZero case by default.") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + + val TRUNCATE_TABLE_IGNORE_PERMISSION_ACL = + buildConf("spark.sql.truncateTable.ignorePermissionAcl.enabled") + .internal() + .doc("When set to true, TRUNCATE TABLE command will not try to set back original " + + "permission and ACLs when re-creating the table/partition paths.") + .version("2.4.6") + .booleanConf + .createWithDefault(false) + + val NAME_NON_STRUCT_GROUPING_KEY_AS_VALUE = + buildConf("spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue") + .internal() + .doc("When set to true, the key attribute resulted from running `Dataset.groupByKey` " + + "for non-struct key type, will be named as `value`, following the behavior of Spark " + + "version 2.4 and earlier.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val MAX_TO_STRING_FIELDS = buildConf("spark.sql.debug.maxToStringFields") + .doc("Maximum number of fields of sequence-like entries can be converted to strings " + + "in debug output. Any elements beyond the limit will be dropped and replaced by a" + + """ "... N more fields" placeholder.""") + .version("3.0.0") + .intConf + .createWithDefault(25) + + val MAX_PLAN_STRING_LENGTH = buildConf("spark.sql.maxPlanStringLength") + .doc("Maximum number of characters to output for a plan string. If the plan is " + + "longer, further output will be truncated. The default setting always generates a full " + + "plan. Set this to a lower value such as 8k if plan strings are taking up too much " + + "memory or are causing OutOfMemory errors in the driver or UI processes.") + .version("3.0.0") + .bytesConf(ByteUnit.BYTE) + .checkValue(i => i >= 0 && i <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH, "Invalid " + + "value for 'spark.sql.maxPlanStringLength'. Length must be a valid string length " + + "(nonnegative and shorter than the maximum size).") + .createWithDefaultString(s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}") + + val MAX_METADATA_STRING_LENGTH = buildConf("spark.sql.maxMetadataStringLength") + .doc("Maximum number of characters to output for a metadata string. e.g. " + + "file location in `DataSourceScanExec`, every value will be abbreviated if exceed length.") + .version("3.1.0") + .intConf + .checkValue(_ > 3, "This value must be bigger than 3.") + .createWithDefault(100) + + val SET_COMMAND_REJECTS_SPARK_CORE_CONFS = + buildConf("spark.sql.legacy.setCommandRejectsSparkCoreConfs") + .internal() + .doc("If it is set to true, SET command will fail when the key is registered as " + + "a SparkConf entry.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + object TimestampTypes extends Enumeration { + val TIMESTAMP_NTZ, TIMESTAMP_LTZ = Value + } + + val TIMESTAMP_TYPE = + buildConf("spark.sql.timestampType") + .doc("Configures the default timestamp type of Spark SQL, including SQL DDL, Cast clause " + + s"and type literal. Setting the configuration as ${TimestampTypes.TIMESTAMP_NTZ} will " + + "use TIMESTAMP WITHOUT TIME ZONE as the default type while putting it as " + + s"${TimestampTypes.TIMESTAMP_LTZ} will use TIMESTAMP WITH LOCAL TIME ZONE. " + + "Before the 3.3.0 release, Spark only supports the TIMESTAMP WITH " + + "LOCAL TIME ZONE type.") + .version("3.3.0") + .internal() + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(TimestampTypes.values.map(_.toString)) + .createWithDefault(TimestampTypes.TIMESTAMP_LTZ.toString) + + val DATETIME_JAVA8API_ENABLED = buildConf("spark.sql.datetime.java8API.enabled") + .doc("If the configuration property is set to true, java.time.Instant and " + + "java.time.LocalDate classes of Java 8 API are used as external types for " + + "Catalyst's TimestampType and DateType. If it is set to false, java.sql.Timestamp " + + "and java.sql.Date are used for the same purpose.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val UI_EXPLAIN_MODE = buildConf("spark.sql.ui.explainMode") + .doc("Configures the query explain mode used in the Spark SQL UI. The value can be 'simple', " + + "'extended', 'codegen', 'cost', or 'formatted'. The default value is 'formatted'.") + .version("3.1.0") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValue(mode => Set("SIMPLE", "EXTENDED", "CODEGEN", "COST", "FORMATTED").contains(mode), + "Invalid value for 'spark.sql.ui.explainMode'. Valid values are 'simple', 'extended', " + + "'codegen', 'cost' and 'formatted'.") + .createWithDefault("formatted") + + val SOURCES_BINARY_FILE_MAX_LENGTH = buildConf("spark.sql.sources.binaryFile.maxLength") + .doc("The max length of a file that can be read by the binary file data source. " + + "Spark will fail fast and not attempt to read the file if its length exceeds this value. " + + "The theoretical max is Int.MaxValue, though VMs might implement a smaller max.") + .version("3.0.0") + .internal() + .intConf + .createWithDefault(Int.MaxValue) + + val LEGACY_CAST_DATETIME_TO_STRING = + buildConf("spark.sql.legacy.typeCoercion.datetimeToString.enabled") + .internal() + .doc("If it is set to true, date/timestamp will cast to string in binary comparisons " + + s"with String when ${ANSI_ENABLED.key} is false.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val DEFAULT_CATALOG = buildConf("spark.sql.defaultCatalog") + .doc("Name of the default catalog. This will be the current catalog if users have not " + + "explicitly set the current catalog yet.") + .version("3.0.0") + .stringConf + .createWithDefault(SESSION_CATALOG_NAME) + + val V2_SESSION_CATALOG_IMPLEMENTATION = + buildConf(s"spark.sql.catalog.$SESSION_CATALOG_NAME") + .doc("A catalog implementation that will be used as the v2 interface to Spark's built-in " + + s"v1 catalog: $SESSION_CATALOG_NAME. This catalog shares its identifier namespace with " + + s"the $SESSION_CATALOG_NAME and must be consistent with it; for example, if a table can " + + s"be loaded by the $SESSION_CATALOG_NAME, this catalog must also return the table " + + s"metadata. To delegate operations to the $SESSION_CATALOG_NAME, implementations can " + + "extend 'CatalogExtension'.") + .version("3.0.0") + .stringConf + .createOptional + + object MapKeyDedupPolicy extends Enumeration { + val EXCEPTION, LAST_WIN = Value + } + + val MAP_KEY_DEDUP_POLICY = buildConf("spark.sql.mapKeyDedupPolicy") + .doc("The policy to deduplicate map keys in builtin function: CreateMap, MapFromArrays, " + + "MapFromEntries, StringToMap, MapConcat and TransformKeys. When EXCEPTION, the query " + + "fails if duplicated map keys are detected. When LAST_WIN, the map key that is inserted " + + "at last takes precedence.") + .version("3.0.0") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(MapKeyDedupPolicy.values.map(_.toString)) + .createWithDefault(MapKeyDedupPolicy.EXCEPTION.toString) + + val LEGACY_LOOSE_UPCAST = buildConf("spark.sql.legacy.doLooseUpcast") + .internal() + .doc("When true, the upcast will be loose and allows string to atomic types.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + object LegacyBehaviorPolicy extends Enumeration { + val EXCEPTION, LEGACY, CORRECTED = Value + } + + val LEGACY_CTE_PRECEDENCE_POLICY = buildConf("spark.sql.legacy.ctePrecedencePolicy") + .internal() + .doc("When LEGACY, outer CTE definitions takes precedence over inner definitions. If set to " + + "CORRECTED, inner CTE definitions take precedence. The default value is EXCEPTION, " + + "AnalysisException is thrown while name conflict is detected in nested CTE. This config " + + "will be removed in future versions and CORRECTED will be the only behavior.") + .version("3.0.0") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) + .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString) + + val LEGACY_TIME_PARSER_POLICY = buildConf("spark.sql.legacy.timeParserPolicy") + .internal() + .doc("When LEGACY, java.text.SimpleDateFormat is used for formatting and parsing " + + "dates/timestamps in a locale-sensitive manner, which is the approach before Spark 3.0. " + + "When set to CORRECTED, classes from java.time.* packages are used for the same purpose. " + + "The default value is EXCEPTION, RuntimeException is thrown when we will get different " + + "results.") + .version("3.0.0") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) + .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString) + + val LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC = + buildConf("spark.sql.legacy.followThreeValuedLogicInArrayExists") + .internal() + .doc("When true, the ArrayExists will follow the three-valued boolean logic.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val ADDITIONAL_REMOTE_REPOSITORIES = + buildConf("spark.sql.maven.additionalRemoteRepositories") + .doc("A comma-delimited string config of the optional additional remote Maven mirror " + + "repositories. This is only used for downloading Hive jars in IsolatedClientLoader " + + "if the default Maven Central repo is unreachable.") + .version("3.0.0") + .stringConf + .createWithDefault( + sys.env.getOrElse("DEFAULT_ARTIFACT_REPOSITORY", + "https://maven-central.storage-download.googleapis.com/maven2/")) + + val LEGACY_FROM_DAYTIME_STRING = + buildConf("spark.sql.legacy.fromDayTimeString.enabled") + .internal() + .doc("When true, the `from` bound is not taken into account in conversion of " + + "a day-time string to an interval, and the `to` bound is used to skip " + + "all interval units out of the specified range. If it is set to `false`, " + + "`ParseException` is thrown if the input does not match to the pattern " + + "defined by `from` and `to`.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val LEGACY_PROPERTY_NON_RESERVED = + buildConf("spark.sql.legacy.notReserveProperties") + .internal() + .doc("When true, all database and table properties are not reserved and available for " + + "create/alter syntaxes. But please be aware that the reserved properties will be " + + "silently removed.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val LEGACY_ADD_SINGLE_FILE_IN_ADD_FILE = + buildConf("spark.sql.legacy.addSingleFileInAddFile") + .internal() + .doc("When true, only a single file can be added using ADD FILE. If false, then users " + + "can add directory by passing directory path to ADD FILE.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val LEGACY_MSSQLSERVER_NUMERIC_MAPPING_ENABLED = + buildConf("spark.sql.legacy.mssqlserver.numericMapping.enabled") + .internal() + .doc("When true, use legacy MySqlServer SMALLINT and REAL type mapping.") + .version("2.4.5") + .booleanConf + .createWithDefault(false) + + val CSV_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.csv.filterPushdown.enabled") + .doc("When true, enable filter pushdown to CSV datasource.") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + val JSON_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.json.filterPushdown.enabled") + .doc("When true, enable filter pushdown to JSON datasource.") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + + val AVRO_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.avro.filterPushdown.enabled") + .doc("When true, enable filter pushdown to Avro datasource.") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + + val ADD_PARTITION_BATCH_SIZE = + buildConf("spark.sql.addPartitionInBatch.size") + .internal() + .doc("The number of partitions to be handled in one turn when use " + + "`AlterTableAddPartitionCommand` or `RepairTableCommand` to add partitions into table. " + + "The smaller batch size is, the less memory is required for the real handler, e.g. " + + "Hive Metastore.") + .version("3.0.0") + .intConf + .checkValue(_ > 0, "The value of spark.sql.addPartitionInBatch.size must be positive") + .createWithDefault(100) + + val LEGACY_ALLOW_HASH_ON_MAPTYPE = buildConf("spark.sql.legacy.allowHashOnMapType") + .internal() + .doc("When set to true, hash expressions can be applied on elements of MapType. Otherwise, " + + "an analysis exception will be thrown.") + .version("3.0.0") + .booleanConf + .createWithDefault(false) + + val LEGACY_INTEGER_GROUPING_ID = + buildConf("spark.sql.legacy.integerGroupingId") + .internal() + .doc("When true, grouping_id() returns int values instead of long values.") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + + val PARQUET_INT96_REBASE_MODE_IN_WRITE = + buildConf("spark.sql.parquet.int96RebaseModeInWrite") + .internal() + .doc("When LEGACY, Spark will rebase INT96 timestamps from Proleptic Gregorian calendar to " + + "the legacy hybrid (Julian + Gregorian) calendar when writing Parquet files. " + + "When CORRECTED, Spark will not do rebase and write the timestamps as it is. " + + "When EXCEPTION, which is the default, Spark will fail the writing if it sees ancient " + + "timestamps that are ambiguous between the two calendars.") + .version("3.1.0") + .withAlternative("spark.sql.legacy.parquet.int96RebaseModeInWrite") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) + .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString) + + val PARQUET_REBASE_MODE_IN_WRITE = + buildConf("spark.sql.parquet.datetimeRebaseModeInWrite") + .internal() + .doc("When LEGACY, Spark will rebase dates/timestamps from Proleptic Gregorian calendar " + + "to the legacy hybrid (Julian + Gregorian) calendar when writing Parquet files. " + + "When CORRECTED, Spark will not do rebase and write the dates/timestamps as it is. " + + "When EXCEPTION, which is the default, Spark will fail the writing if it sees " + + "ancient dates/timestamps that are ambiguous between the two calendars. " + + "This config influences on writes of the following parquet logical types: DATE, " + + "TIMESTAMP_MILLIS, TIMESTAMP_MICROS. The INT96 type has the separate config: " + + s"${PARQUET_INT96_REBASE_MODE_IN_WRITE.key}.") + .version("3.0.0") + .withAlternative("spark.sql.legacy.parquet.datetimeRebaseModeInWrite") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) + .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString) + + val PARQUET_INT96_REBASE_MODE_IN_READ = + buildConf("spark.sql.parquet.int96RebaseModeInRead") + .internal() + .doc("When LEGACY, Spark will rebase INT96 timestamps from the legacy hybrid (Julian + " + + "Gregorian) calendar to Proleptic Gregorian calendar when reading Parquet files. " + + "When CORRECTED, Spark will not do rebase and read the timestamps as it is. " + + "When EXCEPTION, which is the default, Spark will fail the reading if it sees ancient " + + "timestamps that are ambiguous between the two calendars. This config is only effective " + + "if the writer info (like Spark, Hive) of the Parquet files is unknown.") + .version("3.1.0") + .withAlternative("spark.sql.legacy.parquet.int96RebaseModeInRead") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) + .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString) + + val PARQUET_REBASE_MODE_IN_READ = + buildConf("spark.sql.parquet.datetimeRebaseModeInRead") + .internal() + .doc("When LEGACY, Spark will rebase dates/timestamps from the legacy hybrid (Julian + " + + "Gregorian) calendar to Proleptic Gregorian calendar when reading Parquet files. " + + "When CORRECTED, Spark will not do rebase and read the dates/timestamps as it is. " + + "When EXCEPTION, which is the default, Spark will fail the reading if it sees " + + "ancient dates/timestamps that are ambiguous between the two calendars. This config is " + + "only effective if the writer info (like Spark, Hive) of the Parquet files is unknown. " + + "This config influences on reads of the following parquet logical types: DATE, " + + "TIMESTAMP_MILLIS, TIMESTAMP_MICROS. The INT96 type has the separate config: " + + s"${PARQUET_INT96_REBASE_MODE_IN_READ.key}.") + .version("3.0.0") + .withAlternative("spark.sql.legacy.parquet.datetimeRebaseModeInRead") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) + .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString) + + val AVRO_REBASE_MODE_IN_WRITE = + buildConf("spark.sql.avro.datetimeRebaseModeInWrite") + .internal() + .doc("When LEGACY, Spark will rebase dates/timestamps from Proleptic Gregorian calendar " + + "to the legacy hybrid (Julian + Gregorian) calendar when writing Avro files. " + + "When CORRECTED, Spark will not do rebase and write the dates/timestamps as it is. " + + "When EXCEPTION, which is the default, Spark will fail the writing if it sees " + + "ancient dates/timestamps that are ambiguous between the two calendars.") + .version("3.0.0") + .withAlternative("spark.sql.legacy.avro.datetimeRebaseModeInWrite") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) + .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString) + + val AVRO_REBASE_MODE_IN_READ = + buildConf("spark.sql.avro.datetimeRebaseModeInRead") + .internal() + .doc("When LEGACY, Spark will rebase dates/timestamps from the legacy hybrid (Julian + " + + "Gregorian) calendar to Proleptic Gregorian calendar when reading Avro files. " + + "When CORRECTED, Spark will not do rebase and read the dates/timestamps as it is. " + + "When EXCEPTION, which is the default, Spark will fail the reading if it sees " + + "ancient dates/timestamps that are ambiguous between the two calendars. This config is " + + "only effective if the writer info (like Spark, Hive) of the Avro files is unknown.") + .version("3.0.0") + .withAlternative("spark.sql.legacy.avro.datetimeRebaseModeInRead") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) + .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString) + + val SCRIPT_TRANSFORMATION_EXIT_TIMEOUT = + buildConf("spark.sql.scriptTransformation.exitTimeoutInSeconds") + .internal() + .doc("Timeout for executor to wait for the termination of transformation script when EOF.") + .version("3.0.0") + .timeConf(TimeUnit.SECONDS) + .checkValue(_ > 0, "The timeout value must be positive") + .createWithDefault(10L) + + val COALESCE_BUCKETS_IN_JOIN_ENABLED = + buildConf("spark.sql.bucketing.coalesceBucketsInJoin.enabled") + .doc("When true, if two bucketed tables with the different number of buckets are joined, " + + "the side with a bigger number of buckets will be coalesced to have the same number " + + "of buckets as the other side. Bigger number of buckets is divisible by the smaller " + + "number of buckets. Bucket coalescing is applied to sort-merge joins and " + + "shuffled hash join. Note: Coalescing bucketed table can avoid unnecessary shuffling " + + "in join, but it also reduces parallelism and could possibly cause OOM for " + + "shuffled hash join.") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + + val COALESCE_BUCKETS_IN_JOIN_MAX_BUCKET_RATIO = + buildConf("spark.sql.bucketing.coalesceBucketsInJoin.maxBucketRatio") + .doc("The ratio of the number of two buckets being coalesced should be less than or " + + "equal to this value for bucket coalescing to be applied. This configuration only " + + s"has an effect when '${COALESCE_BUCKETS_IN_JOIN_ENABLED.key}' is set to true.") + .version("3.1.0") + .intConf + .checkValue(_ > 0, "The difference must be positive.") + .createWithDefault(4) + + val BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT = + buildConf("spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit") + .internal() + .doc("The maximum number of partitionings that a HashPartitioning can be expanded to. " + + "This configuration is applicable only for BroadcastHashJoin inner joins and can be " + + "set to '0' to disable this feature.") + .version("3.1.0") + .intConf + .checkValue(_ >= 0, "The value must be non-negative.") + .createWithDefault(8) + + val OPTIMIZE_NULL_AWARE_ANTI_JOIN = + buildConf("spark.sql.optimizeNullAwareAntiJoin") + .internal() + .doc("When true, NULL-aware anti join execution will be planed into " + + "BroadcastHashJoinExec with flag isNullAwareAntiJoin enabled, " + + "optimized from O(M*N) calculation into O(M) calculation " + + "using Hash lookup instead of Looping lookup." + + "Only support for singleColumn NAAJ for now.") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + + val LEGACY_COMPLEX_TYPES_TO_STRING = + buildConf("spark.sql.legacy.castComplexTypesToString.enabled") + .internal() + .doc("When true, maps and structs are wrapped by [] in casting to strings, and " + + "NULL elements of structs/maps/arrays will be omitted while converting to strings. " + + "Otherwise, if this is false, which is the default, maps and structs are wrapped by {}, " + + "and NULL elements will be converted to \"null\".") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + + val LEGACY_PATH_OPTION_BEHAVIOR = + buildConf("spark.sql.legacy.pathOptionBehavior.enabled") + .internal() + .doc("When true, \"path\" option is overwritten if one path parameter is passed to " + + "DataFrameReader.load(), DataFrameWriter.save(), DataStreamReader.load(), or " + + "DataStreamWriter.start(). Also, \"path\" option is added to the overall paths if " + + "multiple path parameters are passed to DataFrameReader.load()") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + + val LEGACY_EXTRA_OPTIONS_BEHAVIOR = + buildConf("spark.sql.legacy.extraOptionsBehavior.enabled") + .internal() + .doc("When true, the extra options will be ignored for DataFrameReader.table(). If set it " + + "to false, which is the default, Spark will check if the extra options have the same " + + "key, but the value is different with the table serde properties. If the check passes, " + + "the extra options will be merged with the serde properties as the scan options. " + + "Otherwise, an exception will be thrown.") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + + val DISABLED_JDBC_CONN_PROVIDER_LIST = + buildConf("spark.sql.sources.disabledJdbcConnProviderList") + .internal() + .doc("Configures a list of JDBC connection providers, which are disabled. " + + "The list contains the name of the JDBC connection providers separated by comma.") + .version("3.1.0") + .stringConf + .createWithDefault("") + + val LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT = + buildConf("spark.sql.legacy.createHiveTableByDefault") + .internal() + .doc("When set to true, CREATE TABLE syntax without USING or STORED AS will use Hive " + + s"instead of the value of ${DEFAULT_DATA_SOURCE_NAME.key} as the table provider.") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + + val LEGACY_CHAR_VARCHAR_AS_STRING = + buildConf("spark.sql.legacy.charVarcharAsString") + .internal() + .doc("When true, Spark will not fail if user uses char and varchar type directly in those" + + " APIs that accept or parse data types as parameters, e.g." + + " `SparkSession.read.schema(...)`, `SparkSession.udf.register(...)` but treat them as" + + " string type as Spark 3.0 and earlier.") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + + val CLI_PRINT_HEADER = + buildConf("spark.sql.cli.print.header") + .doc("When set to true, spark-sql CLI prints the names of the columns in query output.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + + val LEGACY_KEEP_COMMAND_OUTPUT_SCHEMA = + buildConf("spark.sql.legacy.keepCommandOutputSchema") + .internal() + .doc("When true, Spark will keep the output schema of commands such as SHOW DATABASES " + + "unchanged, for v1 catalog and/or table.") + .version("3.0.2") + .booleanConf + .createWithDefault(false) + + val LEGACY_INTERVAL_ENABLED = buildConf("spark.sql.legacy.interval.enabled") + .internal() + .doc("When set to true, Spark SQL uses the mixed legacy interval type `CalendarIntervalType` " + + "instead of the ANSI compliant interval types `YearMonthIntervalType` and " + + "`DayTimeIntervalType`. For instance, the date subtraction expression returns " + + "`CalendarIntervalType` when the SQL config is set to `true` otherwise an ANSI interval.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + + val MAX_CONCURRENT_OUTPUT_FILE_WRITERS = buildConf("spark.sql.maxConcurrentOutputFileWriters") + .internal() + .doc("Maximum number of output file writers to use concurrently. If number of writers " + + "needed reaches this limit, task will sort rest of output then writing them.") + .version("3.2.0") + .intConf + .createWithDefault(0) + + /** + * Holds information about keys that have been deprecated. + * + * @param key The deprecated key. + * @param version Version of Spark where key was deprecated. + * @param comment Additional info regarding to the removed config. For example, + * reasons of config deprecation, what users should use instead of it. + */ + case class DeprecatedConfig(key: String, version: String, comment: String) + + /** + * Maps deprecated SQL config keys to information about the deprecation. + * + * The extra information is logged as a warning when the SQL config is present + * in the user's configuration. + */ + val deprecatedSQLConfigs: Map[String, DeprecatedConfig] = { + val configs = Seq( + DeprecatedConfig( + PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key, "2.4", + "The config allows to switch to the behaviour before Spark 2.4 " + + "and will be removed in the future releases."), + DeprecatedConfig(HIVE_VERIFY_PARTITION_PATH.key, "3.0", + s"This config is replaced by '${SPARK_IGNORE_MISSING_FILES.key}'."), + DeprecatedConfig(ARROW_EXECUTION_ENABLED.key, "3.0", + s"Use '${ARROW_PYSPARK_EXECUTION_ENABLED.key}' instead of it."), + DeprecatedConfig(ARROW_FALLBACK_ENABLED.key, "3.0", + s"Use '${ARROW_PYSPARK_FALLBACK_ENABLED.key}' instead of it."), + DeprecatedConfig(SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "3.0", + s"Use '${ADVISORY_PARTITION_SIZE_IN_BYTES.key}' instead of it."), + DeprecatedConfig(OPTIMIZER_METADATA_ONLY.key, "3.0", + "Avoid to depend on this optimization to prevent a potential correctness issue. " + + "If you must use, use 'SparkSessionExtensions' instead to inject it as a custom rule."), + DeprecatedConfig(CONVERT_CTAS.key, "3.1", + s"Set '${LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT.key}' to false instead."), + DeprecatedConfig("spark.sql.sources.schemaStringLengthThreshold", "3.2", + s"Use '${HIVE_TABLE_PROPERTY_LENGTH_THRESHOLD.key}' instead."), + DeprecatedConfig(PARQUET_INT96_REBASE_MODE_IN_WRITE.alternatives.head, "3.2", + s"Use '${PARQUET_INT96_REBASE_MODE_IN_WRITE.key}' instead."), + DeprecatedConfig(PARQUET_INT96_REBASE_MODE_IN_READ.alternatives.head, "3.2", + s"Use '${PARQUET_INT96_REBASE_MODE_IN_READ.key}' instead."), + DeprecatedConfig(PARQUET_REBASE_MODE_IN_WRITE.alternatives.head, "3.2", + s"Use '${PARQUET_REBASE_MODE_IN_WRITE.key}' instead."), + DeprecatedConfig(PARQUET_REBASE_MODE_IN_READ.alternatives.head, "3.2", + s"Use '${PARQUET_REBASE_MODE_IN_READ.key}' instead."), + DeprecatedConfig(AVRO_REBASE_MODE_IN_WRITE.alternatives.head, "3.2", + s"Use '${AVRO_REBASE_MODE_IN_WRITE.key}' instead."), + DeprecatedConfig(AVRO_REBASE_MODE_IN_READ.alternatives.head, "3.2", + s"Use '${AVRO_REBASE_MODE_IN_READ.key}' instead."), + DeprecatedConfig(LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED.key, "3.2", + """Use `.format("avro")` in `DataFrameWriter` or `DataFrameReader` instead."""), + DeprecatedConfig(COALESCE_PARTITIONS_MIN_PARTITION_NUM.key, "3.2", + s"Use '${COALESCE_PARTITIONS_MIN_PARTITION_SIZE.key}' instead.") + ) + + Map(configs.map { cfg => cfg.key -> cfg } : _*) + } + + /** + * Holds information about keys that have been removed. + * + * @param key The removed config key. + * @param version Version of Spark where key was removed. + * @param defaultValue The default config value. It can be used to notice + * users that they set non-default value to an already removed config. + * @param comment Additional info regarding to the removed config. + */ + case class RemovedConfig(key: String, version: String, defaultValue: String, comment: String) + + /** + * The map contains info about removed SQL configs. Keys are SQL config names, + * map values contain extra information like the version in which the config was removed, + * config's default value and a comment. + * + * Please, add a removed SQL configuration property here only when it affects behaviours. + * For example, `spark.sql.variable.substitute.depth` was not added as it virtually + * became no-op later. By this, it makes migrations to new Spark versions painless. + */ + val removedSQLConfigs: Map[String, RemovedConfig] = { + val configs = Seq( + RemovedConfig("spark.sql.fromJsonForceNullableSchema", "3.0.0", "true", + "It was removed to prevent errors like SPARK-23173 for non-default value."), + RemovedConfig( + "spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation", "3.0.0", "false", + "It was removed to prevent loosing of users data for non-default value."), + RemovedConfig("spark.sql.legacy.compareDateTimestampInTimestamp", "3.0.0", "true", + "It was removed to prevent errors like SPARK-23549 for non-default value."), + RemovedConfig("spark.sql.parquet.int64AsTimestampMillis", "3.0.0", "false", + "The config was deprecated since Spark 2.3." + + s"Use '${PARQUET_OUTPUT_TIMESTAMP_TYPE.key}' instead of it."), + RemovedConfig("spark.sql.execution.pandas.respectSessionTimeZone", "3.0.0", "true", + "The non-default behavior is considered as a bug, see SPARK-22395. " + + "The config was deprecated since Spark 2.3."), + RemovedConfig("spark.sql.optimizer.planChangeLog.level", "3.1.0", "trace", + s"Please use `${PLAN_CHANGE_LOG_LEVEL.key}` instead."), + RemovedConfig("spark.sql.optimizer.planChangeLog.rules", "3.1.0", "", + s"Please use `${PLAN_CHANGE_LOG_RULES.key}` instead."), + RemovedConfig("spark.sql.optimizer.planChangeLog.batches", "3.1.0", "", + s"Please use `${PLAN_CHANGE_LOG_BATCHES.key}` instead.") + ) + + Map(configs.map { cfg => cfg.key -> cfg } : _*) + } +} + +/** + * A class that enables the setting and getting of mutable config parameters/hints. + * + * In the presence of a SQLContext, these can be set and queried by passing SET commands + * into Spark SQL's query functions (i.e. sql()). Otherwise, users of this class can + * modify the hints by programmatically calling the setters and getters of this class. + * + * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). + */ +class SQLConf extends Serializable with Logging { + import SQLConf._ + + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ + @transient protected[spark] val settings = java.util.Collections.synchronizedMap( + new java.util.HashMap[String, String]()) + + @transient protected val reader = new ConfigReader(settings) + + /** ************************ Spark SQL Params/Hints ******************* */ + + def analyzerMaxIterations: Int = getConf(ANALYZER_MAX_ITERATIONS) + + def optimizerExcludedRules: Option[String] = getConf(OPTIMIZER_EXCLUDED_RULES) + + def optimizerMaxIterations: Int = getConf(OPTIMIZER_MAX_ITERATIONS) + + def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD) + + def optimizerInSetSwitchThreshold: Int = getConf(OPTIMIZER_INSET_SWITCH_THRESHOLD) + + def planChangeLogLevel: String = getConf(PLAN_CHANGE_LOG_LEVEL) + + def planChangeRules: Option[String] = getConf(PLAN_CHANGE_LOG_RULES) + + def planChangeBatches: Option[String] = getConf(PLAN_CHANGE_LOG_BATCHES) + + def dynamicPartitionPruningEnabled: Boolean = getConf(DYNAMIC_PARTITION_PRUNING_ENABLED) + + def dynamicPartitionPruningUseStats: Boolean = getConf(DYNAMIC_PARTITION_PRUNING_USE_STATS) + + def dynamicPartitionPruningFallbackFilterRatio: Double = + getConf(DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO) + + def dynamicPartitionPruningReuseBroadcastOnly: Boolean = + getConf(DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY) + + def runtimeFilterSemiJoinReductionEnabled: Boolean = + getConf(RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED) + + def runtimeFilterBloomFilterEnabled: Boolean = + getConf(RUNTIME_BLOOM_FILTER_ENABLED) + + def runtimeFilterBloomFilterThreshold: Long = + getConf(RUNTIME_BLOOM_FILTER_THRESHOLD) + + def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS) + + def isStateSchemaCheckEnabled: Boolean = getConf(STATE_SCHEMA_CHECK_ENABLED) + + def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) + + def stateStoreFormatValidationEnabled: Boolean = getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED) + + def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION) + + def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED) + + def useDeprecatedKafkaOffsetFetching: Boolean = getConf(USE_DEPRECATED_KAFKA_OFFSET_FETCHING) + + def statefulOperatorCorrectnessCheckEnabled: Boolean = + getConf(STATEFUL_OPERATOR_CHECK_CORRECTNESS_ENABLED) + + def fileStreamSinkMetadataIgnored: Boolean = getConf(FILESTREAM_SINK_METADATA_IGNORED) + + def streamingFileCommitProtocolClass: String = getConf(STREAMING_FILE_COMMIT_PROTOCOL_CLASS) + + def fileSinkLogDeletion: Boolean = getConf(FILE_SINK_LOG_DELETION) + + def fileSinkLogCompactInterval: Int = getConf(FILE_SINK_LOG_COMPACT_INTERVAL) + + def fileSinkLogCleanupDelay: Long = getConf(FILE_SINK_LOG_CLEANUP_DELAY) + + def fileSourceLogDeletion: Boolean = getConf(FILE_SOURCE_LOG_DELETION) + + def fileSourceLogCompactInterval: Int = getConf(FILE_SOURCE_LOG_COMPACT_INTERVAL) + + def fileSourceLogCleanupDelay: Long = getConf(FILE_SOURCE_LOG_CLEANUP_DELAY) + + def streamingSchemaInference: Boolean = getConf(STREAMING_SCHEMA_INFERENCE) + + def streamingPollingDelay: Long = getConf(STREAMING_POLLING_DELAY) + + def streamingNoDataProgressEventInterval: Long = + getConf(STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL) + + def streamingNoDataMicroBatchesEnabled: Boolean = + getConf(STREAMING_NO_DATA_MICRO_BATCHES_ENABLED) + + def streamingMetricsEnabled: Boolean = getConf(STREAMING_METRICS_ENABLED) + + def streamingProgressRetention: Int = getConf(STREAMING_PROGRESS_RETENTION) + + def filesMaxPartitionBytes: Long = getConf(FILES_MAX_PARTITION_BYTES) + + def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES) + + def filesMinPartitionNum: Option[Int] = getConf(FILES_MIN_PARTITION_NUM) + + def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES) + + def ignoreMissingFiles: Boolean = getConf(IGNORE_MISSING_FILES) + + def maxRecordsPerFile: Long = getConf(MAX_RECORDS_PER_FILE) + + def useCompression: Boolean = getConf(COMPRESS_CACHED) + + def orcCompressionCodec: String = getConf(ORC_COMPRESSION) + + def orcVectorizedReaderEnabled: Boolean = getConf(ORC_VECTORIZED_READER_ENABLED) + + def orcVectorizedReaderBatchSize: Int = getConf(ORC_VECTORIZED_READER_BATCH_SIZE) + + def orcVectorizedReaderNestedColumnEnabled: Boolean = + getConf(ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED) + + def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) + + def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED) + + def parquetVectorizedReaderBatchSize: Int = getConf(PARQUET_VECTORIZED_READER_BATCH_SIZE) + + def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) + + def cacheVectorizedReaderEnabled: Boolean = getConf(CACHE_VECTORIZED_READER_ENABLED) + + def defaultNumShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) + + def numShufflePartitions: Int = { + if (adaptiveExecutionEnabled && coalesceShufflePartitionsEnabled) { + getConf(COALESCE_PARTITIONS_INITIAL_PARTITION_NUM).getOrElse(defaultNumShufflePartitions) + } else { + defaultNumShufflePartitions + } + } + + def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) + + def adaptiveExecutionLogLevel: String = getConf(ADAPTIVE_EXECUTION_LOG_LEVEL) + + def fetchShuffleBlocksInBatch: Boolean = getConf(FETCH_SHUFFLE_BLOCKS_IN_BATCH) + + def nonEmptyPartitionRatioForBroadcastJoin: Double = + getConf(NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN) + + def coalesceShufflePartitionsEnabled: Boolean = getConf(COALESCE_PARTITIONS_ENABLED) + + def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN) + + def maxBatchesToRetainInMemory: Int = getConf(MAX_BATCHES_TO_RETAIN_IN_MEMORY) + + def streamingMaintenanceInterval: Long = getConf(STREAMING_MAINTENANCE_INTERVAL) + + def stateStoreCompressionCodec: String = getConf(STATE_STORE_COMPRESSION_CODEC) + + def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) + + def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED) + + def parquetFilterPushDownTimestamp: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED) + + def parquetFilterPushDownDecimal: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED) + + def parquetFilterPushDownStringStartWith: Boolean = + getConf(PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED) + + def parquetFilterPushDownInFilterThreshold: Int = + getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD) + + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) + + def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED) + + def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) + + def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) + + def metastorePartitionPruningInSetThreshold: Int = + getConf(HIVE_METASTORE_PARTITION_PRUNING_INSET_THRESHOLD) + + def manageFilesourcePartitions: Boolean = getConf(HIVE_MANAGE_FILESOURCE_PARTITIONS) + + def filesourcePartitionFileCacheSize: Long = getConf(HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE) + + def caseSensitiveInferenceMode: HiveCaseSensitiveInferenceMode.Value = + HiveCaseSensitiveInferenceMode.withName(getConf(HIVE_CASE_SENSITIVE_INFERENCE)) + + def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT) + + def optimizerMetadataOnly: Boolean = getConf(OPTIMIZER_METADATA_ONLY) + + def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + + def wholeStageUseIdInClassName: Boolean = getConf(WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME) + + def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS) + + def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK) + + def codegenComments: Boolean = getConf(StaticSQLConf.CODEGEN_COMMENTS) + + def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES) + + def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT) + + def methodSplitThreshold: Int = getConf(CODEGEN_METHOD_SPLIT_THRESHOLD) + + def wholeStageSplitConsumeFuncByOperator: Boolean = + getConf(WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR) + + def tableRelationCacheSize: Int = + getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) + + def codegenCacheMaxEntries: Int = getConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES) + + def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) + + def subqueryReuseEnabled: Boolean = getConf(SUBQUERY_REUSE_ENABLED) + + def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) + + def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED) + + def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) + + def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR) + + def stringRedactionPattern: Option[Regex] = getConf(SQL_STRING_REDACTION_PATTERN) + + def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) + + def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD) + + def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT) + + def streamingSessionWindowMergeSessionInLocalPartition: Boolean = + getConf(STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION) + + def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED) + + def uiExplainMode: String = getConf(UI_EXPLAIN_MODE) + + def addSingleFileInAddFile: Boolean = getConf(LEGACY_ADD_SINGLE_FILE_IN_ADD_FILE) + + def legacyMsSqlServerNumericMappingEnabled: Boolean = + getConf(LEGACY_MSSQLSERVER_NUMERIC_MAPPING_ENABLED) + + def legacyTimeParserPolicy: LegacyBehaviorPolicy.Value = { + LegacyBehaviorPolicy.withName(getConf(SQLConf.LEGACY_TIME_PARSER_POLICY)) + } + + def broadcastHashJoinOutputPartitioningExpandLimit: Int = + getConf(BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT) + + /** + * Returns the [[Resolver]] for the current configuration, which can be used to determine if two + * identifiers are equal. + */ + def resolver: Resolver = { + if (caseSensitiveAnalysis) { + org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution + } else { + org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution + } + } + + /** + * Returns the error handler for handling hint errors. + */ + def hintErrorHandler: HintErrorHandler = HintErrorLogger + + def subexpressionEliminationEnabled: Boolean = + getConf(SUBEXPRESSION_ELIMINATION_ENABLED) + + def subexpressionEliminationCacheMaxEntries: Int = + getConf(SUBEXPRESSION_ELIMINATION_CACHE_MAX_ENTRIES) + + def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD) + + def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) + + def advancedPartitionPredicatePushdownEnabled: Boolean = + getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) + + def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) + + def enableRadixSort: Boolean = getConf(RADIX_SORT_ENABLED) + + def isParquetSchemaMergingEnabled: Boolean = getConf(PARQUET_SCHEMA_MERGING_ENABLED) + + def isParquetSchemaRespectSummaries: Boolean = getConf(PARQUET_SCHEMA_RESPECT_SUMMARIES) + + def parquetOutputCommitterClass: String = getConf(PARQUET_OUTPUT_COMMITTER_CLASS) + + def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING) + + def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) + + def isParquetINT96TimestampConversion: Boolean = getConf(PARQUET_INT96_TIMESTAMP_CONVERSION) + + def parquetOutputTimestampType: ParquetOutputTimestampType.Value = { + ParquetOutputTimestampType.withName(getConf(PARQUET_OUTPUT_TIMESTAMP_TYPE)) + } + + def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) + + def parquetRecordFilterEnabled: Boolean = getConf(PARQUET_RECORD_FILTER_ENABLED) + + def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) + + def inMemoryTableScanStatisticsEnabled: Boolean = getConf(IN_MEMORY_TABLE_SCAN_STATISTICS_ENABLED) + + def offHeapColumnVectorEnabled: Boolean = getConf(COLUMN_VECTOR_OFFHEAP_ENABLED) + + def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) + + def broadcastTimeout: Long = { + val timeoutValue = getConf(BROADCAST_TIMEOUT) + if (timeoutValue < 0) Long.MaxValue else timeoutValue + } + + def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) + + def convertCTAS: Boolean = getConf(CONVERT_CTAS) + + def partitionColumnTypeInferenceEnabled: Boolean = + getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE) + + def fileCommitProtocolClass: String = getConf(SQLConf.FILE_COMMIT_PROTOCOL_CLASS) + + def parallelPartitionDiscoveryThreshold: Int = + getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD) + + def parallelPartitionDiscoveryParallelism: Int = + getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_PARALLELISM) + + def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED) + + def bucketingMaxBuckets: Int = getConf(SQLConf.BUCKETING_MAX_BUCKETS) + + def autoBucketedScanEnabled: Boolean = getConf(SQLConf.AUTO_BUCKETED_SCAN_ENABLED) + + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = + getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) + + def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS) + + def dataFramePivotMaxValues: Int = getConf(DATAFRAME_PIVOT_MAX_VALUES) + + def runSQLonFile: Boolean = getConf(RUN_SQL_ON_FILES) + + def enableTwoLevelAggMap: Boolean = getConf(ENABLE_TWOLEVEL_AGG_MAP) + + def enableVectorizedHashMap: Boolean = getConf(ENABLE_VECTORIZED_HASH_MAP) + + def useObjectHashAggregation: Boolean = getConf(USE_OBJECT_HASH_AGG) + + def objectAggSortBasedFallbackThreshold: Int = getConf(OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD) + + def variableSubstituteEnabled: Boolean = getConf(VARIABLE_SUBSTITUTE_ENABLED) + + def warehousePath: String = new Path(getConf(StaticSQLConf.WAREHOUSE_PATH)).toString + + def hiveThriftServerSingleSession: Boolean = + getConf(StaticSQLConf.HIVE_THRIFT_SERVER_SINGLESESSION) + + def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) + + def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) + + def groupByAliases: Boolean = getConf(GROUP_BY_ALIASES) + + def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) + + def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE) + + def jsonGeneratorIgnoreNullFields: Boolean = getConf(SQLConf.JSON_GENERATOR_IGNORE_NULL_FIELDS) + + def jsonExpressionOptimization: Boolean = getConf(SQLConf.JSON_EXPRESSION_OPTIMIZATION) + + def csvExpressionOptimization: Boolean = getConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION) + + def parallelFileListingInStatsComputation: Boolean = + getConf(SQLConf.PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION) + + def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS) + + def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES) + + def ndvMaxError: Double = getConf(NDV_MAX_ERROR) + + def histogramEnabled: Boolean = getConf(HISTOGRAM_ENABLED) + + def histogramNumBins: Int = getConf(HISTOGRAM_NUM_BINS) + + def percentileAccuracy: Int = getConf(PERCENTILE_ACCURACY) + + def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED) + + def planStatsEnabled: Boolean = getConf(SQLConf.PLAN_STATS_ENABLED) + + def autoSizeUpdateEnabled: Boolean = getConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED) + + def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED) + + def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) + + def joinReorderCardWeight: Double = getConf(SQLConf.JOIN_REORDER_CARD_WEIGHT) + + def joinReorderDPStarFilter: Boolean = getConf(SQLConf.JOIN_REORDER_DP_STAR_FILTER) + + def windowExecBufferInMemoryThreshold: Int = getConf(WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD) + + def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD) + + def sessionWindowBufferInMemoryThreshold: Int = getConf(SESSION_WINDOW_BUFFER_IN_MEMORY_THRESHOLD) + + def sessionWindowBufferSpillThreshold: Int = getConf(SESSION_WINDOW_BUFFER_SPILL_THRESHOLD) + + def sortMergeJoinExecBufferInMemoryThreshold: Int = + getConf(SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD) + + def sortMergeJoinExecBufferSpillThreshold: Int = + getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD) + + def cartesianProductExecBufferInMemoryThreshold: Int = + getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD) + + def cartesianProductExecBufferSpillThreshold: Int = + getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD) + + def codegenSplitAggregateFunc: Boolean = getConf(SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC) + + def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH) + + def useCurrentSQLConfigsForView: Boolean = getConf(SQLConf.USE_CURRENT_SQL_CONFIGS_FOR_VIEW) + + def storeAnalyzedPlanForView: Boolean = getConf(SQLConf.STORE_ANALYZED_PLAN_FOR_VIEW) + + def allowAutoGeneratedAliasForView: Boolean = getConf(SQLConf.ALLOW_AUTO_GENERATED_ALIAS_FOR_VEW) + + def allowStarWithSingleTableIdentifierInCount: Boolean = + getConf(SQLConf.ALLOW_STAR_WITH_SINGLE_TABLE_IDENTIFIER_IN_COUNT) + + def allowNonEmptyLocationInCTAS: Boolean = + getConf(SQLConf.ALLOW_NON_EMPTY_LOCATION_IN_CTAS) + + def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION) + + def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO) + + def supportQuotedRegexColumnName: Boolean = getConf(SUPPORT_QUOTED_REGEX_COLUMN_NAME) + + def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION) + + def arrowPySparkEnabled: Boolean = getConf(ARROW_PYSPARK_EXECUTION_ENABLED) + + def arrowPySparkSelfDestructEnabled: Boolean = getConf(ARROW_PYSPARK_SELF_DESTRUCT_ENABLED) + + def pysparkJVMStacktraceEnabled: Boolean = getConf(PYSPARK_JVM_STACKTRACE_ENABLED) + + def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED) + + def arrowPySparkFallbackEnabled: Boolean = getConf(ARROW_PYSPARK_FALLBACK_ENABLED) + + def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + + def pandasUDFBufferSize: Int = getConf(PANDAS_UDF_BUFFER_SIZE) + + def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIEID_TRACEBACK) + + def pandasGroupedMapAssignColumnsByName: Boolean = + getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME) + + def arrowSafeTypeConversion: Boolean = getConf(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION) + + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) + + def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) + + def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION) + + def continuousStreamingEpochBacklogQueueSize: Int = + getConf(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE) + + def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) + + def continuousStreamingExecutorPollIntervalMs: Long = + getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS) + + def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS) + + def disabledV2StreamingMicroBatchReaders: String = + getConf(DISABLED_V2_STREAMING_MICROBATCH_READERS) + + def fastFailFileFormatOutput: Boolean = getConf(FASTFAIL_ON_FILEFORMAT_OUTPUT) + + def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) + + def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + + def validatePartitionColumns: Boolean = getConf(VALIDATE_PARTITION_COLUMNS) + + def partitionOverwriteMode: PartitionOverwriteMode.Value = + PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + + def storeAssignmentPolicy: StoreAssignmentPolicy.Value = + StoreAssignmentPolicy.withName(getConf(STORE_ASSIGNMENT_POLICY)) + + def ansiEnabled: Boolean = getConf(ANSI_ENABLED) + + def timestampType: AtomicType = getConf(TIMESTAMP_TYPE) match { + // SPARK-36227: Remove TimestampNTZ type support in Spark 3.2 with minimal code changes. + // The configuration `TIMESTAMP_TYPE` is only effective for testing in Spark 3.2. + case "TIMESTAMP_NTZ" if Utils.isTesting => + TimestampNTZType + + case _ => + // For historical reason, the TimestampType maps to TIMESTAMP WITH LOCAL TIME ZONE + TimestampType + } + + def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED) + + def serializerNestedSchemaPruningEnabled: Boolean = + getConf(SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED) + + def nestedPruningOnExpressions: Boolean = getConf(NESTED_PRUNING_ON_EXPRESSIONS) + + def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) + + def legacySizeOfNull: Boolean = { + // size(null) should return null under ansi mode. + getConf(SQLConf.LEGACY_SIZE_OF_NULL) && !getConf(ANSI_ENABLED) + } + + def isReplEagerEvalEnabled: Boolean = getConf(SQLConf.REPL_EAGER_EVAL_ENABLED) + + def replEagerEvalMaxNumRows: Int = getConf(SQLConf.REPL_EAGER_EVAL_MAX_NUM_ROWS) + + def replEagerEvalTruncate: Int = getConf(SQLConf.REPL_EAGER_EVAL_TRUNCATE) + + def avroCompressionCodec: String = getConf(SQLConf.AVRO_COMPRESSION_CODEC) + + def avroDeflateLevel: Int = getConf(SQLConf.AVRO_DEFLATE_LEVEL) + + def replaceDatabricksSparkAvroEnabled: Boolean = + getConf(SQLConf.LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED) + + def setOpsPrecedenceEnforced: Boolean = getConf(SQLConf.LEGACY_SETOPS_PRECEDENCE_ENABLED) + + def exponentLiteralAsDecimalEnabled: Boolean = + getConf(SQLConf.LEGACY_EXPONENT_LITERAL_AS_DECIMAL_ENABLED) + + def allowNegativeScaleOfDecimalEnabled: Boolean = + getConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED) + + def legacyStatisticalAggregate: Boolean = getConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE) + + def truncateTableIgnorePermissionAcl: Boolean = + getConf(SQLConf.TRUNCATE_TABLE_IGNORE_PERMISSION_ACL) + + def nameNonStructGroupingKeyAsValue: Boolean = + getConf(SQLConf.NAME_NON_STRUCT_GROUPING_KEY_AS_VALUE) + + def maxToStringFields: Int = getConf(SQLConf.MAX_TO_STRING_FIELDS) + + def maxPlanStringLength: Int = getConf(SQLConf.MAX_PLAN_STRING_LENGTH).toInt + + def maxMetadataStringLength: Int = getConf(SQLConf.MAX_METADATA_STRING_LENGTH) + + def setCommandRejectsSparkCoreConfs: Boolean = + getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS) + + def castDatetimeToString: Boolean = getConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING) + + def ignoreDataLocality: Boolean = getConf(SQLConf.IGNORE_DATA_LOCALITY) + + def csvFilterPushDown: Boolean = getConf(CSV_FILTER_PUSHDOWN_ENABLED) + + def jsonFilterPushDown: Boolean = getConf(JSON_FILTER_PUSHDOWN_ENABLED) + + def avroFilterPushDown: Boolean = getConf(AVRO_FILTER_PUSHDOWN_ENABLED) + + def integerGroupingIdEnabled: Boolean = getConf(SQLConf.LEGACY_INTEGER_GROUPING_ID) + + def metadataCacheTTL: Long = getConf(StaticSQLConf.METADATA_CACHE_TTL_SECONDS) + + def coalesceBucketsInJoinEnabled: Boolean = getConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED) + + def coalesceBucketsInJoinMaxBucketRatio: Int = + getConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_MAX_BUCKET_RATIO) + + def optimizeNullAwareAntiJoin: Boolean = + getConf(SQLConf.OPTIMIZE_NULL_AWARE_ANTI_JOIN) + + def legacyPathOptionBehavior: Boolean = getConf(SQLConf.LEGACY_PATH_OPTION_BEHAVIOR) + + def disabledJdbcConnectionProviders: String = getConf(SQLConf.DISABLED_JDBC_CONN_PROVIDER_LIST) + + def charVarcharAsString: Boolean = getConf(SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING) + + def cliPrintHeader: Boolean = getConf(SQLConf.CLI_PRINT_HEADER) + + def legacyIntervalEnabled: Boolean = getConf(LEGACY_INTERVAL_ENABLED) + + def decorrelateInnerQueryEnabled: Boolean = getConf(SQLConf.DECORRELATE_INNER_QUERY_ENABLED) + + def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS) + + /** ********************** SQLConf functionality methods ************ */ + + /** Set Spark SQL configuration properties. */ + def setConf(props: Properties): Unit = settings.synchronized { + props.asScala.foreach { case (k, v) => setConfString(k, v) } + } + + /** Set the given Spark SQL configuration property using a `string` value. */ + def setConfString(key: String, value: String): Unit = { + require(key != null, "key cannot be null") + require(value != null, s"value cannot be null for key: $key") + val entry = getConfigEntry(key) + if (entry != null) { + // Only verify configs in the SQLConf object + entry.valueConverter(value) + } + setConfWithCheck(key, value) + } + + /** Set the given Spark SQL configuration property. */ + def setConf[T](entry: ConfigEntry[T], value: T): Unit = { + require(entry != null, "entry cannot be null") + require(value != null, s"value cannot be null for key: ${entry.key}") + require(containsConfigEntry(entry), s"$entry is not registered") + setConfWithCheck(entry.key, entry.stringConverter(value)) + } + + /** Return the value of Spark SQL configuration property for the given key. */ + @throws[NoSuchElementException]("if key is not set") + def getConfString(key: String): String = { + Option(settings.get(key)). + orElse { + // Try to use the default value + Option(getConfigEntry(key)).map { e => e.stringConverter(e.readFrom(reader)) } + }. + getOrElse(throw QueryExecutionErrors.noSuchElementExceptionError(key)) + } + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue`. This is useful when `defaultValue` in ConfigEntry is not the + * desired one. + */ + def getConf[T](entry: ConfigEntry[T], defaultValue: T): T = { + require(containsConfigEntry(entry), s"$entry is not registered") + Option(settings.get(entry.key)).map(entry.valueConverter).getOrElse(defaultValue) + } + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue` in [[ConfigEntry]]. + */ + def getConf[T](entry: ConfigEntry[T]): T = { + require(containsConfigEntry(entry), s"$entry is not registered") + entry.readFrom(reader) + } + + /** + * Return the value of an optional Spark SQL configuration property for the given key. If the key + * is not set yet, returns None. + */ + def getConf[T](entry: OptionalConfigEntry[T]): Option[T] = { + require(containsConfigEntry(entry), s"$entry is not registered") + entry.readFrom(reader) + } + + /** + * Return the `string` value of Spark SQL configuration property for the given key. If the key is + * not set yet, return `defaultValue`. + */ + def getConfString(key: String, defaultValue: String): String = { + Option(settings.get(key)).getOrElse { + // If the key is not set, need to check whether the config entry is registered and is + // a fallback conf, so that we can check its parent. + getConfigEntry(key) match { + case e: FallbackConfigEntry[_] => + getConfString(e.fallback.key, defaultValue) + case e: ConfigEntry[_] if defaultValue != null && defaultValue != ConfigEntry.UNDEFINED => + // Only verify configs in the SQLConf object + e.valueConverter(defaultValue) + defaultValue + case _ => + defaultValue + } + } + } + + private var definedConfsLoaded = false + /** + * Init [[StaticSQLConf]] and [[org.apache.spark.sql.hive.HiveUtils]] so that all the defined + * SQL Configurations will be registered to SQLConf + */ + private def loadDefinedConfs(): Unit = { + if (!definedConfsLoaded) { + definedConfsLoaded = true + // Force to register static SQL configurations + StaticSQLConf + try { + // Force to register SQL configurations from Hive module + val symbol = ScalaReflection.mirror.staticModule("org.apache.spark.sql.hive.HiveUtils") + ScalaReflection.mirror.reflectModule(symbol).instance + } catch { + case NonFatal(e) => + logWarning("SQL configurations from Hive module is not loaded", e) + } + } + } + + /** + * Return all the configuration properties that have been set (i.e. not the default). + * This creates a new copy of the config properties in the form of a Map. + */ + def getAllConfs: immutable.Map[String, String] = + settings.synchronized { settings.asScala.toMap } + + /** + * Return all the configuration definitions that have been defined in [[SQLConf]]. Each + * definition contains key, defaultValue and doc. + */ + def getAllDefinedConfs: Seq[(String, String, String, String)] = { + loadDefinedConfs() + getConfigEntries().asScala.filter(_.isPublic).map { entry => + val displayValue = Option(getConfString(entry.key, null)).getOrElse(entry.defaultValueString) + (entry.key, displayValue, entry.doc, entry.version) + }.toSeq + } + + /** + * Redacts the given option map according to the description of SQL_OPTIONS_REDACTION_PATTERN. + */ + def redactOptions[K, V](options: Map[K, V]): Map[K, V] = { + redactOptions(options.toSeq).toMap + } + + /** + * Redacts the given option map according to the description of SQL_OPTIONS_REDACTION_PATTERN. + */ + def redactOptions[K, V](options: Seq[(K, V)]): Seq[(K, V)] = { + val regexes = Seq( + getConf(SQL_OPTIONS_REDACTION_PATTERN), + SECRET_REDACTION_PATTERN.readFrom(reader)) + + regexes.foldLeft(options) { case (opts, r) => Utils.redact(Some(r), opts) } + } + + /** + * Return whether a given key is set in this [[SQLConf]]. + */ + def contains(key: String): Boolean = { + settings.containsKey(key) + } + + /** + * Logs a warning message if the given config key is deprecated. + */ + private def logDeprecationWarning(key: String): Unit = { + SQLConf.deprecatedSQLConfigs.get(key).foreach { + case DeprecatedConfig(configName, version, comment) => + logWarning( + s"The SQL config '$configName' has been deprecated in Spark v$version " + + s"and may be removed in the future. $comment") + } + } + + private def requireDefaultValueOfRemovedConf(key: String, value: String): Unit = { + SQLConf.removedSQLConfigs.get(key).foreach { + case RemovedConfig(configName, version, defaultValue, comment) => + if (value != defaultValue) { + throw QueryCompilationErrors.configRemovedInVersionError(configName, version, comment) + } + } + } + + protected def setConfWithCheck(key: String, value: String): Unit = { + logDeprecationWarning(key) + requireDefaultValueOfRemovedConf(key, value) + settings.put(key, value) + } + + def unsetConf(key: String): Unit = { + logDeprecationWarning(key) + settings.remove(key) + } + + def unsetConf(entry: ConfigEntry[_]): Unit = { + unsetConf(entry.key) + } + + def clear(): Unit = { + settings.clear() + } + + override def clone(): SQLConf = { + val result = new SQLConf + getAllConfs.foreach { + case(k, v) => if (v ne null) result.setConfString(k, v) + } + result + } + + // For test only + def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + val cloned = clone() + entries.foreach { + case (entry, value) => cloned.setConfString(entry.key, value.toString) + } + cloned + } + + def isModifiable(key: String): Boolean = { + containsConfigKey(key) && !isStaticConfigKey(key) + } +}