diff --git a/.github/ISSUE_TEMPLATE/code-contrib-task.yml b/.github/ISSUE_TEMPLATE/code-contrib-task.yml
new file mode 100644
index 00000000000..3191e4fe48d
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/code-contrib-task.yml
@@ -0,0 +1,115 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#
+# This is a dedicated issue template for the 2023 Kyuubi Code Contribution Program; all proposed
+# tasks will be listed at https://github.com/orgs/apache/projects/296 after approval
+#
+name: 2023 Kyuubi Code Contribution Task
+title: "[TASK][] "
+description: Propose a task for 2023 Kyuubi Code Contribution Program
+labels: [ "hacktoberfest" ]
+body:
+  - type: markdown
+    attributes:
+      value: |
+        You are very welcome to propose a new task for the 2023 Kyuubi Code Contribution Program.
+        Your brilliant ideas keep Apache Kyuubi evolving.
+        Please replace the placeholder `` in the issue title with one of the following options:
+        - TRIVIAL - usually for new contributors to learn the contribution process, e.g. how to cut a branch,
+          how to use GitHub to send a PR, and how to respond to reviewers; contributors should not stay at this
+          stage too long.
+        - EASY - tasks like minor bugs or simple features that do not require knowledge of the whole Kyuubi
+          architecture.
+        - MEDIUM - tasks that typically require knowledge of one or more Kyuubi components; normally, unit
+          tests and integration tests are also required to verify the implementation.
+        - CHALLENGE - tasks that require contributors to have deep knowledge of one or more Kyuubi components,
+          good logical thinking and the ability to solve complex problems, and to be proficient in programming
+          skills or algorithms.
+
+  - type: checkboxes
+    attributes:
+      label: Code of Conduct
+      description: The Code of Conduct helps create a safe space for everyone. We require that everyone agrees to it.
+      options:
+        - label: >
+            I agree to follow this project's [Code of Conduct](https://www.apache.org/foundation/policies/conduct)
+          required: true
+
+  - type: checkboxes
+    attributes:
+      label: Search before creating
+      options:
+        - label: >
+            I have searched in the [task list](https://github.com/orgs/apache/projects/296) and found no similar
+            tasks.
+          required: true
+
+  - type: checkboxes
+    attributes:
+      label: Mentor
+      description: A mentor is required for MEDIUM and CHALLENGE tasks to guide contributors to complete the task.
+      options:
+        - label: >
+            I have sufficient knowledge and experience of this task, and I volunteer to be the mentor of this task
+            to guide contributors to complete the task.
+          required: false
+
+  - type: textarea
+    attributes:
+      label: Skill requirements
+      description: Which skills are required for contributors who want to take this task?
+      placeholder: |
+        e.g.
+        - Basic knowledge of the Scala programming language
+        - Familiar with Apache Maven, Docker and GitHub Actions
+        - Basic knowledge of network programming and the Apache Thrift RPC framework
+        - Familiar with Apache Spark
+        - ...
+    validations:
+      required: true
+
+  - type: textarea
+    attributes:
+      label: Background and Goals
+      description: What's the current problem, and what should the final status be after the task is completed?
+      placeholder: >
+        Please describe the background and your goal for requesting this task.
+    validations:
+      required: true
+
+  - type: textarea
+    attributes:
+      label: Implementation steps
+      description: How could it be implemented?
+      placeholder: >
+        Please list the implementation steps in as much detail as possible so that contributors who meet
+        the skill requirements could complete the task quickly and independently.
+    validations:
+      required: true
+
+  - type: textarea
+    attributes:
+      label: Additional context
+      placeholder: >
+        Anything else related to this task that contributors need to know.
+    validations:
+      required: false
+
+  - type: markdown
+    attributes:
+      value: "Thanks for taking the time to fill out this task form!"
diff --git a/.github/workflows/license.yml b/.github/workflows/license.yml
index 91c53a7a173..55ef485f8fe 100644
--- a/.github/workflows/license.yml
+++ b/.github/workflows/license.yml
@@ -45,7 +45,7 @@ jobs:
       - run: >-
           build/mvn org.apache.rat:apache-rat-plugin:check
           -Ptpcds -Pspark-block-cleaner -Pkubernetes-it
-          -Pspark-3.1 -Pspark-3.2 -Pspark-3.3 -Pspark-3.4
+          -Pspark-3.1 -Pspark-3.2 -Pspark-3.3 -Pspark-3.4 -Pspark-3.5
       - name: Upload rat report
         if: failure()
         uses: actions/upload-artifact@v3
diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml
index 3b85530d44a..7c442dd0f48 100644
--- a/.github/workflows/master.yml
+++ b/.github/workflows/master.yml
@@ -52,6 +52,7 @@ jobs:
           - '3.2'
           - '3.3'
           - '3.4'
+          - '3.5'
         spark-archive: [""]
         exclude-tags: [""]
         comment: ["normal"]
@@ -169,16 +170,18 @@
           **/target/unit-tests.log
           **/kyuubi-spark-sql-engine.log*

-  scala213:
-    name: Scala Compilation Test
+  scala-test:
+    name: Scala Test
     runs-on: ubuntu-22.04
     strategy:
       fail-fast: false
       matrix:
-        java:
-          - '8'
         scala:
           - '2.13'
+        java:
+          - '8'
+        spark:
+          - '3.4'
     steps:
       - uses: actions/checkout@v3
       - name: Tune Runner VM
@@ -192,14 +195,24 @@
           check-latest: false
       - name: Setup Maven
         uses: ./.github/actions/setup-maven
+      - name: Cache Engine Archives
+        uses: ./.github/actions/cache-engine-archives
       - name: Build on Scala ${{ matrix.scala }}
         run: |
-          MODULES='!externals/kyuubi-flink-sql-engine'
-          ./build/mvn clean install -pl ${MODULES} -am \
-          -DskipTests -Pflink-provided,hive-provided,spark-provided \
-          -Pjava-${{ matrix.java }} \
-          -Pscala-${{ matrix.scala }} \
-          -Pspark-3.3
+          TEST_MODULES="!externals/kyuubi-flink-sql-engine,!integration-tests/kyuubi-flink-it"
+          ./build/mvn clean install ${MVN_OPT} -pl ${TEST_MODULES} -am \
+            -Pscala-${{ matrix.scala }} -Pjava-${{ matrix.java }} -Pspark-${{ matrix.spark }}
+      - name: Upload test logs
+        if: failure()
+        uses: actions/upload-artifact@v3
+        with:
+          name: unit-tests-log-scala-${{ matrix.scala }}-java-${{ matrix.java }}-spark-${{ matrix.spark }}
+          path: |
+            **/target/unit-tests.log
+            **/kyuubi-spark-sql-engine.log*
+            **/kyuubi-spark-batch-submit.log*
+            **/kyuubi-jdbc-engine.log*
+            **/kyuubi-hive-sql-engine.log*

   flink-it:
     name: Flink Test
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index 6f575302ea1..21cacbc1de7 100644
--- a/.github/workflows/style.yml
+++
b/.github/workflows/style.yml @@ -69,6 +69,7 @@ jobs: build/mvn clean install ${MVN_OPT} -pl extensions/spark/kyuubi-extension-spark-3-1 -Pspark-3.1 build/mvn clean install ${MVN_OPT} -pl extensions/spark/kyuubi-extension-spark-3-3,extensions/spark/kyuubi-spark-connector-hive -Pspark-3.3 build/mvn clean install ${MVN_OPT} -pl extensions/spark/kyuubi-extension-spark-3-4 -Pspark-3.4 + build/mvn clean install ${MVN_OPT} -pl extensions/spark/kyuubi-extension-spark-3-5 -Pspark-3.5 - name: Scalastyle with maven id: scalastyle-check diff --git a/build/dist b/build/dist index b81a2661ece..df9498008cb 100755 --- a/build/dist +++ b/build/dist @@ -335,7 +335,7 @@ if [[ -f "$KYUUBI_HOME/tools/spark-block-cleaner/target/spark-block-cleaner_${SC fi # Copy Kyuubi Spark extension -SPARK_EXTENSION_VERSIONS=('3-1' '3-2' '3-3' '3-4') +SPARK_EXTENSION_VERSIONS=('3-1' '3-2' '3-3' '3-4' '3-5') # shellcheck disable=SC2068 for SPARK_EXTENSION_VERSION in ${SPARK_EXTENSION_VERSIONS[@]}; do if [[ -f $"$KYUUBI_HOME/extensions/spark/kyuubi-extension-spark-$SPARK_EXTENSION_VERSION/target/kyuubi-extension-spark-${SPARK_EXTENSION_VERSION}_${SCALA_VERSION}-${VERSION}.jar" ]]; then diff --git a/dev/kyuubi-codecov/pom.xml b/dev/kyuubi-codecov/pom.xml index 31b9d27bc03..a5ec582f961 100644 --- a/dev/kyuubi-codecov/pom.xml +++ b/dev/kyuubi-codecov/pom.xml @@ -219,5 +219,15 @@ + + spark-3.5 + + + org.apache.kyuubi + kyuubi-extension-spark-3-5_${scala.binary.version} + ${project.version} + + + diff --git a/dev/reformat b/dev/reformat index 6346e68f68d..31e8f49ad21 100755 --- a/dev/reformat +++ b/dev/reformat @@ -20,7 +20,7 @@ set -x KYUUBI_HOME="$(cd "`dirname "$0"`/.."; pwd)" -PROFILES="-Pflink-provided,hive-provided,spark-provided,spark-block-cleaner,spark-3.4,spark-3.3,spark-3.2,spark-3.1,tpcds" +PROFILES="-Pflink-provided,hive-provided,spark-provided,spark-block-cleaner,spark-3.5,spark-3.4,spark-3.3,spark-3.2,spark-3.1,tpcds" # python style checks rely on `black` in path if ! command -v black &> /dev/null diff --git a/docs/quick_start/quick_start_with_jdbc.md b/docs/quick_start/quick_start_with_jdbc.md index e6f4f705296..abd4fbec4b9 100644 --- a/docs/quick_start/quick_start_with_jdbc.md +++ b/docs/quick_start/quick_start_with_jdbc.md @@ -35,7 +35,7 @@ The driver is available from Maven Central: ## Connect to non-kerberized Kyuubi Server -The below java code is using a keytab file to login and connect to Kyuubi server by JDBC. +The following java code connects directly to the Kyuubi Server by JDBC without using kerberos authentication. 
```java package org.apache.kyuubi.examples; @@ -50,7 +50,7 @@ public class KyuubiJDBC { public static void main(String[] args) throws SQLException { try (Connection conn = DriverManager.getConnection(kyuubiJdbcUrl)) { try (Statement stmt = conn.createStatement()) { - try (ResultSet rs = st.executeQuery("show databases")) { + try (ResultSet rs = stmt.executeQuery("show databases")) { while (rs.next()) { System.out.println(rs.getString(1)); } @@ -79,11 +79,11 @@ public class KyuubiJDBCDemo { public static void main(String[] args) throws SQLException { String clientPrincipal = args[0]; // Kerberos principal String clientKeytab = args[1]; // Keytab file location - String serverPrincipal = arg[2]; // Kerberos principal used by Kyuubi Server + String serverPrincipal = args[2]; // Kerberos principal used by Kyuubi Server String kyuubiJdbcUrl = String.format(kyuubiJdbcUrlTemplate, clientPrincipal, clientKeytab, serverPrincipal); try (Connection conn = DriverManager.getConnection(kyuubiJdbcUrl)) { try (Statement stmt = conn.createStatement()) { - try (ResultSet rs = st.executeQuery("show databases")) { + try (ResultSet rs = stmt.executeQuery("show databases")) { while (rs.next()) { System.out.println(rs.getString(1)); } diff --git a/extensions/spark/kyuubi-extension-spark-3-5/pom.xml b/extensions/spark/kyuubi-extension-spark-3-5/pom.xml new file mode 100644 index 00000000000..e78a88a8002 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/pom.xml @@ -0,0 +1,206 @@ + + + + 4.0.0 + + org.apache.kyuubi + kyuubi-parent + 1.9.0-SNAPSHOT + ../../../pom.xml + + + kyuubi-extension-spark-3-5_${scala.binary.version} + jar + Kyuubi Dev Spark Extensions (for Spark 3.5) + https://kyuubi.apache.org/ + + + + org.scala-lang + scala-library + provided + + + + org.apache.spark + spark-sql_${scala.binary.version} + provided + + + + org.apache.spark + spark-hive_${scala.binary.version} + provided + + + + org.apache.hadoop + hadoop-client-api + provided + + + + org.apache.kyuubi + kyuubi-download + ${project.version} + pom + test + + + + org.apache.kyuubi + kyuubi-util-scala_${scala.binary.version} + ${project.version} + test-jar + test + + + + org.apache.spark + spark-core_${scala.binary.version} + test-jar + test + + + + org.apache.spark + spark-catalyst_${scala.binary.version} + test-jar + test + + + + org.scalatestplus + scalacheck-1-17_${scala.binary.version} + test + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + test-jar + test + + + + org.apache.hadoop + hadoop-client-runtime + test + + + + + commons-collections + commons-collections + test + + + + commons-io + commons-io + test + + + + jakarta.xml.bind + jakarta.xml.bind-api + test + + + + org.apache.logging.log4j + log4j-slf4j-impl + test + + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + regex-property + + regex-property + + + spark.home + ${project.basedir}/../../../externals/kyuubi-download/target/${spark.archive.name} + (.+)\.tgz + $1 + + + + + + org.scalatest + scalatest-maven-plugin + + + + ${spark.home} + ${scala.binary.version} + + + + + org.antlr + antlr4-maven-plugin + + true + ${project.basedir}/src/main/antlr4 + + + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + org.apache.kyuubi:* + + + + + + + shade + + package + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/antlr4/org/apache/kyuubi/sql/KyuubiSparkSQL.g4 
b/extensions/spark/kyuubi-extension-spark-3-5/src/main/antlr4/org/apache/kyuubi/sql/KyuubiSparkSQL.g4 new file mode 100644 index 00000000000..e52b7f5cfeb --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/antlr4/org/apache/kyuubi/sql/KyuubiSparkSQL.g4 @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +grammar KyuubiSparkSQL; + +@members { + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } + } + +tokens { + DELIMITER +} + +singleStatement + : statement EOF + ; + +statement + : OPTIMIZE multipartIdentifier whereClause? zorderClause #optimizeZorder + | .*? #passThrough + ; + +whereClause + : WHERE partitionPredicate = predicateToken + ; + +zorderClause + : ZORDER BY order+=multipartIdentifier (',' order+=multipartIdentifier)* + ; + +// We don't have an expression rule in our grammar here, so we just grab the tokens and defer +// parsing them to later. +predicateToken + : .+? + ; + +multipartIdentifier + : parts+=identifier ('.' parts+=identifier)* + ; + +identifier + : strictIdentifier + ; + +strictIdentifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +nonReserved + : AND + | BY + | FALSE + | DATE + | INTERVAL + | OPTIMIZE + | OR + | TABLE + | TIMESTAMP + | TRUE + | WHERE + | ZORDER + ; + +AND: 'AND'; +BY: 'BY'; +FALSE: 'FALSE'; +DATE: 'DATE'; +INTERVAL: 'INTERVAL'; +OPTIMIZE: 'OPTIMIZE'; +OR: 'OR'; +TABLE: 'TABLE'; +TIMESTAMP: 'TIMESTAMP'; +TRUE: 'TRUE'; +WHERE: 'WHERE'; +ZORDER: 'ZORDER'; + +MINUS: '-'; + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +DECIMAL_VALUE + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT? {isValidDecimal()}? + ; + +DOUBLE_LITERAL + : DIGIT+ EXPONENT? 
'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? + ; + +BIGDECIMAL_LITERAL + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' .*? '*/' -> channel(HIDDEN) + ; + +WS : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . + ; diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/DropIgnoreNonexistent.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/DropIgnoreNonexistent.scala new file mode 100644 index 00000000000..e33632b8b30 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/DropIgnoreNonexistent.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kyuubi.sql + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunctionName, UnresolvedRelation} +import org.apache.spark.sql.catalyst.plans.logical.{DropFunction, DropNamespace, LogicalPlan, NoopCommand, UncacheTable} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.command.{AlterTableDropPartitionCommand, DropTableCommand} + +import org.apache.kyuubi.sql.KyuubiSQLConf._ + +case class DropIgnoreNonexistent(session: SparkSession) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (conf.getConf(DROP_IGNORE_NONEXISTENT)) { + plan match { + case i @ AlterTableDropPartitionCommand(_, _, false, _, _) => + i.copy(ifExists = true) + case i @ DropTableCommand(_, false, _, _) => + i.copy(ifExists = true) + case i @ DropNamespace(_, false, _) => + i.copy(ifExists = true) + case UncacheTable(u: UnresolvedRelation, false, _) => + NoopCommand("UNCACHE TABLE", u.multipartIdentifier) + case DropFunction(u: UnresolvedFunctionName, false) => + NoopCommand("DROP FUNCTION", u.multipartIdentifier) + case _ => plan + } + } else { + plan + } + } + +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/InferRebalanceAndSortOrders.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/InferRebalanceAndSortOrders.scala new file mode 100644 index 00000000000..fcbf5c0a122 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/InferRebalanceAndSortOrders.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql + +import scala.annotation.tailrec + +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression, UnaryExpression} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort, SubqueryAlias, View} + +/** + * Infer the columns for Rebalance and Sort to improve the compression ratio. 
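+ * Rebalancing by these keys clusters equal values into the same partitions and the local sort
+ * keeps them adjacent, which typically compresses better. Depending on the plan shape, the
+ * candidates are the join keys, the aggregate grouping expressions, or an existing sort order.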
+ * + * For example + * {{{ + * INSERT INTO TABLE t PARTITION(p='a') + * SELECT * FROM t1 JOIN t2 on t1.c1 = t2.c1 + * }}} + * the inferred columns are: t1.c1 + */ +object InferRebalanceAndSortOrders { + + type PartitioningAndOrdering = (Seq[Expression], Seq[Expression]) + + private def getAliasMap(named: Seq[NamedExpression]): Map[Expression, Attribute] = { + @tailrec + def throughUnary(e: Expression): Expression = e match { + case u: UnaryExpression if u.deterministic => + throughUnary(u.child) + case _ => e + } + + named.flatMap { + case a @ Alias(child, _) => + Some((throughUnary(child).canonicalized, a.toAttribute)) + case _ => None + }.toMap + } + + def infer(plan: LogicalPlan): Option[PartitioningAndOrdering] = { + def candidateKeys( + input: LogicalPlan, + output: AttributeSet = AttributeSet.empty): Option[PartitioningAndOrdering] = { + input match { + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _, _, _) => + joinType match { + case LeftSemi | LeftAnti | LeftOuter => Some((leftKeys, leftKeys)) + case RightOuter => Some((rightKeys, rightKeys)) + case Inner | FullOuter => + if (output.isEmpty) { + Some((leftKeys ++ rightKeys, leftKeys ++ rightKeys)) + } else { + assert(leftKeys.length == rightKeys.length) + val keys = leftKeys.zip(rightKeys).flatMap { case (left, right) => + if (left.references.subsetOf(output)) { + Some(left) + } else if (right.references.subsetOf(output)) { + Some(right) + } else { + None + } + } + Some((keys, keys)) + } + case _ => None + } + case agg: Aggregate => + val aliasMap = getAliasMap(agg.aggregateExpressions) + Some(( + agg.groupingExpressions.map(p => aliasMap.getOrElse(p.canonicalized, p)), + agg.groupingExpressions.map(o => aliasMap.getOrElse(o.canonicalized, o)))) + case s: Sort => Some((s.order.map(_.child), s.order.map(_.child))) + case p: Project => + val aliasMap = getAliasMap(p.projectList) + candidateKeys(p.child, p.references).map { case (partitioning, ordering) => + ( + partitioning.map(p => aliasMap.getOrElse(p.canonicalized, p)), + ordering.map(o => aliasMap.getOrElse(o.canonicalized, o))) + } + case f: Filter => candidateKeys(f.child, output) + case s: SubqueryAlias => candidateKeys(s.child, output) + case v: View => candidateKeys(v.child, output) + + case _ => None + } + } + + candidateKeys(plan).map { case (partitioning, ordering) => + ( + partitioning.filter(_.references.subsetOf(plan.outputSet)), + ordering.filter(_.references.subsetOf(plan.outputSet))) + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/InsertShuffleNodeBeforeJoin.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/InsertShuffleNodeBeforeJoin.scala new file mode 100644 index 00000000000..1a02e8c1e67 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/InsertShuffleNodeBeforeJoin.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql + +import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{SortExec, SparkPlan} +import org.apache.spark.sql.execution.adaptive.QueryStageExec +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.internal.SQLConf + +import org.apache.kyuubi.sql.KyuubiSQLConf._ + +/** + * Insert shuffle node before join if it doesn't exist to make `OptimizeSkewedJoin` works. + */ +object InsertShuffleNodeBeforeJoin extends Rule[SparkPlan] { + + override def apply(plan: SparkPlan): SparkPlan = { + // this rule has no meaning without AQE + if (!conf.getConf(FORCE_SHUFFLE_BEFORE_JOIN) || + !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED)) { + return plan + } + + val newPlan = insertShuffleBeforeJoin(plan) + if (plan.fastEquals(newPlan)) { + plan + } else { + // make sure the output partitioning and ordering will not be broken. + KyuubiEnsureRequirements.apply(newPlan) + } + } + + // Since spark 3.3, insertShuffleBeforeJoin shouldn't be applied if join is skewed. + private def insertShuffleBeforeJoin(plan: SparkPlan): SparkPlan = plan transformUp { + case smj @ SortMergeJoinExec(_, _, _, _, l, r, isSkewJoin) if !isSkewJoin => + smj.withNewChildren(checkAndInsertShuffle(smj.requiredChildDistribution.head, l) :: + checkAndInsertShuffle(smj.requiredChildDistribution(1), r) :: Nil) + + case shj: ShuffledHashJoinExec if !shj.isSkewJoin => + if (!shj.left.isInstanceOf[Exchange] && !shj.right.isInstanceOf[Exchange]) { + shj.withNewChildren(withShuffleExec(shj.requiredChildDistribution.head, shj.left) :: + withShuffleExec(shj.requiredChildDistribution(1), shj.right) :: Nil) + } else if (!shj.left.isInstanceOf[Exchange]) { + shj.withNewChildren( + withShuffleExec(shj.requiredChildDistribution.head, shj.left) :: shj.right :: Nil) + } else if (!shj.right.isInstanceOf[Exchange]) { + shj.withNewChildren( + shj.left :: withShuffleExec(shj.requiredChildDistribution(1), shj.right) :: Nil) + } else { + shj + } + } + + private def checkAndInsertShuffle( + distribution: Distribution, + child: SparkPlan): SparkPlan = child match { + case SortExec(_, _, _: Exchange, _) => + child + case SortExec(_, _, _: QueryStageExec, _) => + child + case sort @ SortExec(_, _, agg: BaseAggregateExec, _) => + sort.withNewChildren(withShuffleExec(distribution, agg) :: Nil) + case _ => + withShuffleExec(distribution, child) + } + + private def withShuffleExec(distribution: Distribution, child: SparkPlan): SparkPlan = { + val numPartitions = distribution.requiredNumPartitions + .getOrElse(conf.numShufflePartitions) + ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiEnsureRequirements.scala 
b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiEnsureRequirements.scala new file mode 100644 index 00000000000..586cad838b4 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiEnsureRequirements.scala @@ -0,0 +1,464 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper +import org.apache.spark.sql.execution.{SortExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.execution.exchange._ +import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.internal.SQLConf + +/** + * Copy from Apache Spark `EnsureRequirements` + * 1. remove reorder join predicates + * 2. remove shuffle pruning + */ +object KyuubiEnsureRequirements extends Rule[SparkPlan] { + + private def ensureDistributionAndOrdering( + parent: Option[SparkPlan], + originalChildren: Seq[SparkPlan], + requiredChildDistributions: Seq[Distribution], + requiredChildOrderings: Seq[Seq[SortOrder]], + shuffleOrigin: ShuffleOrigin): Seq[SparkPlan] = { + assert(requiredChildDistributions.length == originalChildren.length) + assert(requiredChildOrderings.length == originalChildren.length) + // Ensure that the operator's children satisfy their output distribution requirements. + var children = originalChildren.zip(requiredChildDistributions).map { + case (child, distribution) if child.outputPartitioning.satisfies(distribution) => + child + case (child, BroadcastDistribution(mode)) => + BroadcastExchangeExec(mode, child) + case (child, distribution) => + val numPartitions = distribution.requiredNumPartitions + .getOrElse(conf.numShufflePartitions) + ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child, shuffleOrigin) + } + + // Get the indexes of children which have specified distribution requirements and need to be + // co-partitioned. + val childrenIndexes = requiredChildDistributions.zipWithIndex.filter { + case (_: ClusteredDistribution, _) => true + case _ => false + }.map(_._2) + + // Special case: if all sides of the join are single partition and it's physical size less than + // or equal spark.sql.maxSinglePartitionBytes. 
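+    // When that holds, the SinglePartition children already satisfy the required distribution,
+    // so the co-partitioning handling below is skipped rather than shuffling a few tiny inputs.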
+ val preferSinglePartition = childrenIndexes.forall { i => + children(i).outputPartitioning == SinglePartition && + children(i).logicalLink + .forall(_.stats.sizeInBytes <= conf.getConf(SQLConf.MAX_SINGLE_PARTITION_BYTES)) + } + + // If there are more than one children, we'll need to check partitioning & distribution of them + // and see if extra shuffles are necessary. + if (childrenIndexes.length > 1 && !preferSinglePartition) { + val specs = childrenIndexes.map(i => { + val requiredDist = requiredChildDistributions(i) + assert( + requiredDist.isInstanceOf[ClusteredDistribution], + s"Expected ClusteredDistribution but found ${requiredDist.getClass.getSimpleName}") + i -> children(i).outputPartitioning.createShuffleSpec( + requiredDist.asInstanceOf[ClusteredDistribution]) + }).toMap + + // Find out the shuffle spec that gives better parallelism. Currently this is done by + // picking the spec with the largest number of partitions. + // + // NOTE: this is not optimal for the case when there are more than 2 children. Consider: + // (10, 10, 11) + // where the number represent the number of partitions for each child, it's better to pick 10 + // here since we only need to shuffle one side - we'd need to shuffle two sides if we pick 11. + // + // However this should be sufficient for now since in Spark nodes with multiple children + // always have exactly 2 children. + + // Whether we should consider `spark.sql.shuffle.partitions` and ensure enough parallelism + // during shuffle. To achieve a good trade-off between parallelism and shuffle cost, we only + // consider the minimum parallelism iff ALL children need to be re-shuffled. + // + // A child needs to be re-shuffled iff either one of below is true: + // 1. It can't create partitioning by itself, i.e., `canCreatePartitioning` returns false + // (as for the case of `RangePartitioning`), therefore it needs to be re-shuffled + // according to other shuffle spec. + // 2. It already has `ShuffleExchangeLike`, so we can re-use existing shuffle without + // introducing extra shuffle. + // + // On the other hand, in scenarios such as: + // HashPartitioning(5) <-> HashPartitioning(6) + // while `spark.sql.shuffle.partitions` is 10, we'll only re-shuffle the left side and make it + // HashPartitioning(6). + val shouldConsiderMinParallelism = specs.forall(p => + !p._2.canCreatePartitioning || children(p._1).isInstanceOf[ShuffleExchangeLike]) + // Choose all the specs that can be used to shuffle other children + val candidateSpecs = specs + .filter(_._2.canCreatePartitioning) + .filter(p => + !shouldConsiderMinParallelism || + children(p._1).outputPartitioning.numPartitions >= conf.defaultNumShufflePartitions) + val bestSpecOpt = if (candidateSpecs.isEmpty) { + None + } else { + // When choosing specs, we should consider those children with no `ShuffleExchangeLike` node + // first. For instance, if we have: + // A: (No_Exchange, 100) <---> B: (Exchange, 120) + // it's better to pick A and change B to (Exchange, 100) instead of picking B and insert a + // new shuffle for A. + val candidateSpecsWithoutShuffle = candidateSpecs.filter { case (k, _) => + !children(k).isInstanceOf[ShuffleExchangeLike] + } + val finalCandidateSpecs = if (candidateSpecsWithoutShuffle.nonEmpty) { + candidateSpecsWithoutShuffle + } else { + candidateSpecs + } + // Pick the spec with the best parallelism + Some(finalCandidateSpecs.values.maxBy(_.numPartitions)) + } + + // Check if the following conditions are satisfied: + // 1. There are exactly two children (e.g., join). 
Note that Spark doesn't support + // multi-way join at the moment, so this check should be sufficient. + // 2. All children are of `KeyGroupedPartitioning`, and they are compatible with each other + // If both are true, skip shuffle. + val isKeyGroupCompatible = parent.isDefined && + children.length == 2 && childrenIndexes.length == 2 && { + val left = children.head + val right = children(1) + val newChildren = checkKeyGroupCompatible( + parent.get, + left, + right, + requiredChildDistributions) + if (newChildren.isDefined) { + children = newChildren.get + } + newChildren.isDefined + } + + children = children.zip(requiredChildDistributions).zipWithIndex.map { + case ((child, _), idx) if isKeyGroupCompatible || !childrenIndexes.contains(idx) => + child + case ((child, dist), idx) => + if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) { + child + } else { + val newPartitioning = bestSpecOpt.map { bestSpec => + // Use the best spec to create a new partitioning to re-shuffle this child + val clustering = dist.asInstanceOf[ClusteredDistribution].clustering + bestSpec.createPartitioning(clustering) + }.getOrElse { + // No best spec available, so we create default partitioning from the required + // distribution + val numPartitions = dist.requiredNumPartitions + .getOrElse(conf.numShufflePartitions) + dist.createPartitioning(numPartitions) + } + + child match { + case ShuffleExchangeExec(_, c, so, ps) => + ShuffleExchangeExec(newPartitioning, c, so, ps) + case _ => ShuffleExchangeExec(newPartitioning, child) + } + } + } + } + + // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings: + children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => + // If child.outputOrdering already satisfies the requiredOrdering, we do not need to sort. + if (SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering)) { + child + } else { + SortExec(requiredOrdering, global = false, child = child) + } + } + + children + } + + /** + * Checks whether two children, `left` and `right`, of a join operator have compatible + * `KeyGroupedPartitioning`, and can benefit from storage-partitioned join. + * + * Returns the updated new children if the check is successful, otherwise `None`. 
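+   * Only [[SortMergeJoinExec]] and [[ShuffledHashJoinExec]] parents are considered here; any
+   * other operator returns `None` immediately.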
+ */ + private def checkKeyGroupCompatible( + parent: SparkPlan, + left: SparkPlan, + right: SparkPlan, + requiredChildDistribution: Seq[Distribution]): Option[Seq[SparkPlan]] = { + parent match { + case smj: SortMergeJoinExec => + checkKeyGroupCompatible(left, right, smj.joinType, requiredChildDistribution) + case sj: ShuffledHashJoinExec => + checkKeyGroupCompatible(left, right, sj.joinType, requiredChildDistribution) + case _ => + None + } + } + + private def checkKeyGroupCompatible( + left: SparkPlan, + right: SparkPlan, + joinType: JoinType, + requiredChildDistribution: Seq[Distribution]): Option[Seq[SparkPlan]] = { + assert(requiredChildDistribution.length == 2) + + var newLeft = left + var newRight = right + + val specs = Seq(left, right).zip(requiredChildDistribution).map { case (p, d) => + if (!d.isInstanceOf[ClusteredDistribution]) return None + val cd = d.asInstanceOf[ClusteredDistribution] + val specOpt = createKeyGroupedShuffleSpec(p.outputPartitioning, cd) + if (specOpt.isEmpty) return None + specOpt.get + } + + val leftSpec = specs.head + val rightSpec = specs(1) + + var isCompatible = false + if (!conf.v2BucketingPushPartValuesEnabled) { + isCompatible = leftSpec.isCompatibleWith(rightSpec) + } else { + logInfo("Pushing common partition values for storage-partitioned join") + isCompatible = leftSpec.areKeysCompatible(rightSpec) + + // Partition expressions are compatible. Regardless of whether partition values + // match from both sides of children, we can calculate a superset of partition values and + // push-down to respective data sources so they can adjust their output partitioning by + // filling missing partition keys with empty partitions. As result, we can still avoid + // shuffle. + // + // For instance, if two sides of a join have partition expressions + // `day(a)` and `day(b)` respectively + // (the join query could be `SELECT ... FROM t1 JOIN t2 on t1.a = t2.b`), but + // with different partition values: + // `day(a)`: [0, 1] + // `day(b)`: [1, 2, 3] + // Following the case 2 above, we don't have to shuffle both sides, but instead can + // just push the common set of partition values: `[0, 1, 2, 3]` down to the two data + // sources. + if (isCompatible) { + val leftPartValues = leftSpec.partitioning.partitionValues + val rightPartValues = rightSpec.partitioning.partitionValues + + logInfo( + s""" + |Left side # of partitions: ${leftPartValues.size} + |Right side # of partitions: ${rightPartValues.size} + |""".stripMargin) + + // As partition keys are compatible, we can pick either left or right as partition + // expressions + val partitionExprs = leftSpec.partitioning.expressions + + var mergedPartValues = InternalRowComparableWrapper + .mergePartitions(leftSpec.partitioning, rightSpec.partitioning, partitionExprs) + .map(v => (v, 1)) + + logInfo(s"After merging, there are ${mergedPartValues.size} partitions") + + var replicateLeftSide = false + var replicateRightSide = false + var applyPartialClustering = false + + // This means we allow partitions that are not clustered on their values, + // that is, multiple partitions with the same partition value. In the + // following, we calculate how many partitions that each distinct partition + // value has, and pushdown the information to scans, so they can adjust their + // final input partitions respectively. 
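+        // (Illustration: if the non-replicated side reports partition values [0, 0, 1], the
+        // pushed-down counts are {0 -> 2, 1 -> 1}.)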
+ if (conf.v2BucketingPartiallyClusteredDistributionEnabled) { + logInfo("Calculating partially clustered distribution for " + + "storage-partitioned join") + + // Similar to `OptimizeSkewedJoin`, we need to check join type and decide + // whether partially clustered distribution can be applied. For instance, the + // optimization cannot be applied to a left outer join, where the left hand + // side is chosen as the side to replicate partitions according to stats. + // Otherwise, query result could be incorrect. + val canReplicateLeft = canReplicateLeftSide(joinType) + val canReplicateRight = canReplicateRightSide(joinType) + + if (!canReplicateLeft && !canReplicateRight) { + logInfo("Skipping partially clustered distribution as it cannot be applied for " + + s"join type '$joinType'") + } else { + val leftLink = left.logicalLink + val rightLink = right.logicalLink + + replicateLeftSide = + if (leftLink.isDefined && rightLink.isDefined && + leftLink.get.stats.sizeInBytes > 1 && + rightLink.get.stats.sizeInBytes > 1) { + logInfo( + s""" + |Using plan statistics to determine which side of join to fully + |cluster partition values: + |Left side size (in bytes): ${leftLink.get.stats.sizeInBytes} + |Right side size (in bytes): ${rightLink.get.stats.sizeInBytes} + |""".stripMargin) + leftLink.get.stats.sizeInBytes < rightLink.get.stats.sizeInBytes + } else { + // As a simple heuristic, we pick the side with fewer number of partitions + // to apply the grouping & replication of partitions + logInfo("Using number of partitions to determine which side of join " + + "to fully cluster partition values") + leftPartValues.size < rightPartValues.size + } + + replicateRightSide = !replicateLeftSide + + // Similar to skewed join, we need to check the join type to see whether replication + // of partitions can be applied. For instance, replication should not be allowed for + // the left-hand side of a right outer join. + if (replicateLeftSide && !canReplicateLeft) { + logInfo("Left-hand side is picked but cannot be applied to join type " + + s"'$joinType'. Skipping partially clustered distribution.") + replicateLeftSide = false + } else if (replicateRightSide && !canReplicateRight) { + logInfo("Right-hand side is picked but cannot be applied to join type " + + s"'$joinType'. 
Skipping partially clustered distribution.") + replicateRightSide = false + } else { + val partValues = if (replicateLeftSide) rightPartValues else leftPartValues + val numExpectedPartitions = partValues + .map(InternalRowComparableWrapper(_, partitionExprs)) + .groupBy(identity) + .mapValues(_.size) + + mergedPartValues = mergedPartValues.map { case (partVal, numParts) => + ( + partVal, + numExpectedPartitions.getOrElse( + InternalRowComparableWrapper(partVal, partitionExprs), + numParts)) + } + + logInfo("After applying partially clustered distribution, there are " + + s"${mergedPartValues.map(_._2).sum} partitions.") + applyPartialClustering = true + } + } + } + + // Now we need to push-down the common partition key to the scan in each child + newLeft = populatePartitionValues( + left, + mergedPartValues, + applyPartialClustering, + replicateLeftSide) + newRight = populatePartitionValues( + right, + mergedPartValues, + applyPartialClustering, + replicateRightSide) + } + } + + if (isCompatible) Some(Seq(newLeft, newRight)) else None + } + + // Similar to `OptimizeSkewedJoin.canSplitRightSide` + private def canReplicateLeftSide(joinType: JoinType): Boolean = { + joinType == Inner || joinType == Cross || joinType == RightOuter + } + + // Similar to `OptimizeSkewedJoin.canSplitLeftSide` + private def canReplicateRightSide(joinType: JoinType): Boolean = { + joinType == Inner || joinType == Cross || joinType == LeftSemi || + joinType == LeftAnti || joinType == LeftOuter + } + + // Populate the common partition values down to the scan nodes + private def populatePartitionValues( + plan: SparkPlan, + values: Seq[(InternalRow, Int)], + applyPartialClustering: Boolean, + replicatePartitions: Boolean): SparkPlan = plan match { + case scan: BatchScanExec => + scan.copy(spjParams = scan.spjParams.copy( + commonPartitionValues = Some(values), + applyPartialClustering = applyPartialClustering, + replicatePartitions = replicatePartitions)) + case node => + node.mapChildren(child => + populatePartitionValues( + child, + values, + applyPartialClustering, + replicatePartitions)) + } + + /** + * Tries to create a [[KeyGroupedShuffleSpec]] from the input partitioning and distribution, if + * the partitioning is a [[KeyGroupedPartitioning]] (either directly or indirectly), and + * satisfies the given distribution. 
+ */ + private def createKeyGroupedShuffleSpec( + partitioning: Partitioning, + distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = { + def check(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = { + val attributes = partitioning.expressions.flatMap(_.collectLeaves()) + val clustering = distribution.clustering + + val satisfies = if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) { + attributes.length == clustering.length && attributes.zip(clustering).forall { + case (l, r) => l.semanticEquals(r) + } + } else { + partitioning.satisfies(distribution) + } + + if (satisfies) { + Some(partitioning.createShuffleSpec(distribution).asInstanceOf[KeyGroupedShuffleSpec]) + } else { + None + } + } + + partitioning match { + case p: KeyGroupedPartitioning => check(p) + case PartitioningCollection(partitionings) => + val specs = partitionings.map(p => createKeyGroupedShuffleSpec(p, distribution)) + assert(specs.forall(_.isEmpty) || specs.forall(_.isDefined)) + specs.head + case _ => None + } + } + + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case operator: SparkPlan => + val newChildren = ensureDistributionAndOrdering( + Some(operator), + operator.children, + operator.requiredChildDistribution, + operator.requiredChildOrdering, + ENSURE_REQUIREMENTS) + operator.withNewChildren(newChildren) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiQueryStagePreparation.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiQueryStagePreparation.scala new file mode 100644 index 00000000000..a7fcbecd422 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiQueryStagePreparation.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.QueryStageExec +import org.apache.spark.sql.execution.command.{ResetCommand, SetCommand} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec, ShuffleExchangeLike} +import org.apache.spark.sql.internal.SQLConf + +import org.apache.kyuubi.sql.KyuubiSQLConf._ + +/** + * This rule split stage into two parts: + * 1. previous stage + * 2. final stage + * For final stage, we can inject extra config. It's useful if we use repartition to optimize + * small files that needs bigger shuffle partition size than previous. 
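+ *
+ * A sketch of the intended usage (illustrative value; it also requires
+ * spark.sql.optimizer.finalStageConfigIsolation.enabled=true):
+ *   SET spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes=512m;
+ * then only the final stage of each query targets the larger advisory partition size.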
+ * + * Let's say we have a query with 3 stages, then the logical machine like: + * + * Set/Reset Command -> cleanup previousStage config if user set the spark config. + * Query -> AQE -> stage1 -> preparation (use previousStage to overwrite spark config) + * -> AQE -> stage2 -> preparation (use spark config) + * -> AQE -> stage3 -> preparation (use finalStage config to overwrite spark config, + * store spark config to previousStage.) + * + * An example of the new finalStage config: + * `spark.sql.adaptive.advisoryPartitionSizeInBytes` -> + * `spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes` + */ +case class FinalStageConfigIsolation(session: SparkSession) extends Rule[SparkPlan] { + import FinalStageConfigIsolation._ + + override def apply(plan: SparkPlan): SparkPlan = { + // this rule has no meaning without AQE + if (!conf.getConf(FINAL_STAGE_CONFIG_ISOLATION) || + !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED)) { + return plan + } + + if (isFinalStage(plan)) { + // We can not get the whole plan at query preparation phase to detect if current plan is + // for writing, so we depend on a tag which is been injected at post resolution phase. + // Note: we should still do clean up previous config for non-final stage to avoid such case: + // the first statement is write, but the second statement is query. + if (conf.getConf(FINAL_STAGE_CONFIG_ISOLATION_WRITE_ONLY) && + !WriteUtils.isWrite(session, plan)) { + return plan + } + + // set config for final stage + session.conf.getAll.filter(_._1.startsWith(FINAL_STAGE_CONFIG_PREFIX)).foreach { + case (k, v) => + val sparkConfigKey = s"spark.sql.${k.substring(FINAL_STAGE_CONFIG_PREFIX.length)}" + val previousStageConfigKey = + s"$PREVIOUS_STAGE_CONFIG_PREFIX${k.substring(FINAL_STAGE_CONFIG_PREFIX.length)}" + // store the previous config only if we have not stored, to avoid some query only + // have one stage that will overwrite real config. + if (!session.sessionState.conf.contains(previousStageConfigKey)) { + val originalValue = + if (session.conf.getOption(sparkConfigKey).isDefined) { + session.sessionState.conf.getConfString(sparkConfigKey) + } else { + // the default value of config is None, so we need to use a internal tag + INTERNAL_UNSET_CONFIG_TAG + } + logInfo(s"Store config: $sparkConfigKey to previousStage, " + + s"original value: $originalValue ") + session.sessionState.conf.setConfString(previousStageConfigKey, originalValue) + } + logInfo(s"For final stage: set $sparkConfigKey = $v.") + session.conf.set(sparkConfigKey, v) + } + } else { + // reset config for previous stage + session.conf.getAll.filter(_._1.startsWith(PREVIOUS_STAGE_CONFIG_PREFIX)).foreach { + case (k, v) => + val sparkConfigKey = s"spark.sql.${k.substring(PREVIOUS_STAGE_CONFIG_PREFIX.length)}" + logInfo(s"For previous stage: set $sparkConfigKey = $v.") + if (v == INTERNAL_UNSET_CONFIG_TAG) { + session.conf.unset(sparkConfigKey) + } else { + session.conf.set(sparkConfigKey, v) + } + // unset config so that we do not need to reset configs for every previous stage + session.conf.unset(k) + } + } + + plan + } + + /** + * Currently formula depend on AQE in Spark 3.1.1, not sure it can work in future. 
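+   * The heuristic: a plan is treated as the final stage when it contains no shuffle at all, or
+   * when every exchange in it (shuffle + broadcast + reused) has already been materialized as a
+   * query stage, i.e. shuffleNum + broadcastNum + reusedNum == queryStageNum.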
+ */ + private def isFinalStage(plan: SparkPlan): Boolean = { + var shuffleNum = 0 + var broadcastNum = 0 + var reusedNum = 0 + var queryStageNum = 0 + + def collectNumber(p: SparkPlan): SparkPlan = { + p transform { + case shuffle: ShuffleExchangeLike => + shuffleNum += 1 + shuffle + + case broadcast: BroadcastExchangeLike => + broadcastNum += 1 + broadcast + + case reusedExchangeExec: ReusedExchangeExec => + reusedNum += 1 + reusedExchangeExec + + // query stage is leaf node so we need to transform it manually + // compatible with Spark 3.5: + // SPARK-42101: table cache is a independent query stage, so do not need include it. + case queryStage: QueryStageExec if queryStage.nodeName != "TableCacheQueryStage" => + queryStageNum += 1 + collectNumber(queryStage.plan) + queryStage + } + } + collectNumber(plan) + + if (shuffleNum == 0) { + // we don not care about broadcast stage here since it won't change partition number. + true + } else if (shuffleNum + broadcastNum + reusedNum == queryStageNum) { + true + } else { + false + } + } +} +object FinalStageConfigIsolation { + final val SQL_PREFIX = "spark.sql." + final val FINAL_STAGE_CONFIG_PREFIX = "spark.sql.finalStage." + final val PREVIOUS_STAGE_CONFIG_PREFIX = "spark.sql.previousStage." + final val INTERNAL_UNSET_CONFIG_TAG = "__INTERNAL_UNSET_CONFIG_TAG__" + + def getPreviousStageConfigKey(configKey: String): Option[String] = { + if (configKey.startsWith(SQL_PREFIX)) { + Some(s"$PREVIOUS_STAGE_CONFIG_PREFIX${configKey.substring(SQL_PREFIX.length)}") + } else { + None + } + } +} + +case class FinalStageConfigIsolationCleanRule(session: SparkSession) extends Rule[LogicalPlan] { + import FinalStageConfigIsolation._ + + override def apply(plan: LogicalPlan): LogicalPlan = plan match { + case set @ SetCommand(Some((k, Some(_)))) if k.startsWith(SQL_PREFIX) => + checkAndUnsetPreviousStageConfig(k) + set + + case reset @ ResetCommand(Some(k)) if k.startsWith(SQL_PREFIX) => + checkAndUnsetPreviousStageConfig(k) + reset + + case other => other + } + + private def checkAndUnsetPreviousStageConfig(configKey: String): Unit = { + getPreviousStageConfigKey(configKey).foreach { previousStageConfigKey => + if (session.sessionState.conf.contains(previousStageConfigKey)) { + logInfo(s"For previous stage: unset $previousStageConfigKey") + session.conf.unset(previousStageConfigKey) + } + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala new file mode 100644 index 00000000000..6f45dae126e --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql + +import org.apache.spark.network.util.ByteUnit +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf._ + +object KyuubiSQLConf { + + val INSERT_REPARTITION_BEFORE_WRITE = + buildConf("spark.sql.optimizer.insertRepartitionBeforeWrite.enabled") + .doc("Add repartition node at the top of query plan. An approach of merging small files.") + .version("1.2.0") + .booleanConf + .createWithDefault(true) + + val INSERT_REPARTITION_NUM = + buildConf("spark.sql.optimizer.insertRepartitionNum") + .doc(s"The partition number if ${INSERT_REPARTITION_BEFORE_WRITE.key} is enabled. " + + s"If AQE is disabled, the default value is ${SQLConf.SHUFFLE_PARTITIONS.key}. " + + "If AQE is enabled, the default value is none that means depend on AQE. " + + "This config is used for Spark 3.1 only.") + .version("1.2.0") + .intConf + .createOptional + + val DYNAMIC_PARTITION_INSERTION_REPARTITION_NUM = + buildConf("spark.sql.optimizer.dynamicPartitionInsertionRepartitionNum") + .doc(s"The partition number of each dynamic partition if " + + s"${INSERT_REPARTITION_BEFORE_WRITE.key} is enabled. " + + "We will repartition by dynamic partition columns to reduce the small file but that " + + "can cause data skew. This config is to extend the partition of dynamic " + + "partition column to avoid skew but may generate some small files.") + .version("1.2.0") + .intConf + .createWithDefault(100) + + val FORCE_SHUFFLE_BEFORE_JOIN = + buildConf("spark.sql.optimizer.forceShuffleBeforeJoin.enabled") + .doc("Ensure shuffle node exists before shuffled join (shj and smj) to make AQE " + + "`OptimizeSkewedJoin` works (complex scenario join, multi table join).") + .version("1.2.0") + .booleanConf + .createWithDefault(false) + + val FINAL_STAGE_CONFIG_ISOLATION = + buildConf("spark.sql.optimizer.finalStageConfigIsolation.enabled") + .doc("If true, the final stage support use different config with previous stage. " + + "The prefix of final stage config key should be `spark.sql.finalStage.`." + + "For example, the raw spark config: `spark.sql.adaptive.advisoryPartitionSizeInBytes`, " + + "then the final stage config should be: " + + "`spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes`.") + .version("1.2.0") + .booleanConf + .createWithDefault(false) + + val SQL_CLASSIFICATION = "spark.sql.analyzer.classification" + val SQL_CLASSIFICATION_ENABLED = + buildConf("spark.sql.analyzer.classification.enabled") + .doc("When true, allows Kyuubi engine to judge this SQL's classification " + + s"and set `$SQL_CLASSIFICATION` back into sessionConf. " + + "Through this configuration item, Spark can optimizing configuration dynamic") + .version("1.4.0") + .booleanConf + .createWithDefault(false) + + val INSERT_ZORDER_BEFORE_WRITING = + buildConf("spark.sql.optimizer.insertZorderBeforeWriting.enabled") + .doc("When true, we will follow target table properties to insert zorder or not. " + + "The key properties are: 1) kyuubi.zorder.enabled; if this property is true, we will " + + "insert zorder before writing data. 2) kyuubi.zorder.cols; string split by comma, we " + + "will zorder by these cols.") + .version("1.4.0") + .booleanConf + .createWithDefault(true) + + val ZORDER_GLOBAL_SORT_ENABLED = + buildConf("spark.sql.optimizer.zorderGlobalSort.enabled") + .doc("When true, we do a global sort using zorder. 
Note that, it can cause data skew " + + "issue if the zorder columns have less cardinality. When false, we only do local sort " + + "using zorder.") + .version("1.4.0") + .booleanConf + .createWithDefault(true) + + val REBALANCE_BEFORE_ZORDER = + buildConf("spark.sql.optimizer.rebalanceBeforeZorder.enabled") + .doc("When true, we do a rebalance before zorder in case data skew. " + + "Note that, if the insertion is dynamic partition we will use the partition " + + "columns to rebalance. Note that, this config only affects with Spark 3.3.x") + .version("1.6.0") + .booleanConf + .createWithDefault(false) + + val REBALANCE_ZORDER_COLUMNS_ENABLED = + buildConf("spark.sql.optimizer.rebalanceZorderColumns.enabled") + .doc(s"When true and ${REBALANCE_BEFORE_ZORDER.key} is true, we do rebalance before " + + s"Z-Order. If it's dynamic partition insert, the rebalance expression will include " + + s"both partition columns and Z-Order columns. Note that, this config only " + + s"affects with Spark 3.3.x") + .version("1.6.0") + .booleanConf + .createWithDefault(false) + + val TWO_PHASE_REBALANCE_BEFORE_ZORDER = + buildConf("spark.sql.optimizer.twoPhaseRebalanceBeforeZorder.enabled") + .doc(s"When true and ${REBALANCE_BEFORE_ZORDER.key} is true, we do two phase rebalance " + + s"before Z-Order for the dynamic partition write. The first phase rebalance using " + + s"dynamic partition column; The second phase rebalance using dynamic partition column + " + + s"Z-Order columns. Note that, this config only affects with Spark 3.3.x") + .version("1.6.0") + .booleanConf + .createWithDefault(false) + + val ZORDER_USING_ORIGINAL_ORDERING_ENABLED = + buildConf("spark.sql.optimizer.zorderUsingOriginalOrdering.enabled") + .doc(s"When true and ${REBALANCE_BEFORE_ZORDER.key} is true, we do sort by " + + s"the original ordering i.e. lexicographical order. Note that, this config only " + + s"affects with Spark 3.3.x") + .version("1.6.0") + .booleanConf + .createWithDefault(false) + + val WATCHDOG_MAX_PARTITIONS = + buildConf("spark.sql.watchdog.maxPartitions") + .doc("Set the max partition number when spark scans a data source. " + + "Enable maxPartitions Strategy by specifying this configuration. " + + "Add maxPartitions Strategy to avoid scan excessive partitions " + + "on partitioned table, it's optional that works with defined") + .version("1.4.0") + .intConf + .createOptional + + val WATCHDOG_MAX_FILE_SIZE = + buildConf("spark.sql.watchdog.maxFileSize") + .doc("Set the maximum size in bytes of files when spark scans a data source. " + + "Enable maxFileSize Strategy by specifying this configuration. 
" + + "Add maxFileSize Strategy to avoid scan excessive size of files," + + " it's optional that works with defined") + .version("1.8.0") + .bytesConf(ByteUnit.BYTE) + .createOptional + + val WATCHDOG_FORCED_MAXOUTPUTROWS = + buildConf("spark.sql.watchdog.forcedMaxOutputRows") + .doc("Add ForcedMaxOutputRows rule to avoid huge output rows of non-limit query " + + "unexpectedly, it's optional that works with defined") + .version("1.4.0") + .intConf + .createOptional + + val DROP_IGNORE_NONEXISTENT = + buildConf("spark.sql.optimizer.dropIgnoreNonExistent") + .doc("Do not report an error if DROP DATABASE/TABLE/VIEW/FUNCTION/PARTITION specifies " + + "a non-existent database/table/view/function/partition") + .version("1.5.0") + .booleanConf + .createWithDefault(false) + + val INFER_REBALANCE_AND_SORT_ORDERS = + buildConf("spark.sql.optimizer.inferRebalanceAndSortOrders.enabled") + .doc("When ture, infer columns for rebalance and sort orders from original query, " + + "e.g. the join keys from join. It can avoid compression ratio regression.") + .version("1.7.0") + .booleanConf + .createWithDefault(false) + + val INFER_REBALANCE_AND_SORT_ORDERS_MAX_COLUMNS = + buildConf("spark.sql.optimizer.inferRebalanceAndSortOrdersMaxColumns") + .doc("The max columns of inferred columns.") + .version("1.7.0") + .intConf + .checkValue(_ > 0, "must be positive number") + .createWithDefault(3) + + val INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE = + buildConf("spark.sql.optimizer.insertRepartitionBeforeWriteIfNoShuffle.enabled") + .doc("When true, add repartition even if the original plan does not have shuffle.") + .version("1.7.0") + .booleanConf + .createWithDefault(false) + + val FINAL_STAGE_CONFIG_ISOLATION_WRITE_ONLY = + buildConf("spark.sql.optimizer.finalStageConfigIsolationWriteOnly.enabled") + .doc("When true, only enable final stage isolation for writing.") + .version("1.7.0") + .booleanConf + .createWithDefault(true) + + val FINAL_WRITE_STAGE_EAGERLY_KILL_EXECUTORS_ENABLED = + buildConf("spark.sql.finalWriteStage.eagerlyKillExecutors.enabled") + .doc("When true, eagerly kill redundant executors before running final write stage.") + .version("1.8.0") + .booleanConf + .createWithDefault(false) + + val FINAL_WRITE_STAGE_EAGERLY_KILL_EXECUTORS_KILL_ALL = + buildConf("spark.sql.finalWriteStage.eagerlyKillExecutors.killAll") + .doc("When true, eagerly kill all executors before running final write stage. 
" + + "Mainly for test.") + .version("1.8.0") + .booleanConf + .createWithDefault(false) + + val FINAL_WRITE_STAGE_SKIP_KILLING_EXECUTORS_FOR_TABLE_CACHE = + buildConf("spark.sql.finalWriteStage.skipKillingExecutorsForTableCache") + .doc("When true, skip killing executors if the plan has table caches.") + .version("1.8.0") + .booleanConf + .createWithDefault(true) + + val FINAL_WRITE_STAGE_PARTITION_FACTOR = + buildConf("spark.sql.finalWriteStage.retainExecutorsFactor") + .doc("If the target executors * factor < active executors, and " + + "target executors * factor > min executors, then kill redundant executors.") + .version("1.8.0") + .doubleConf + .checkValue(_ >= 1, "must be bigger than or equal to 1") + .createWithDefault(1.2) + + val FINAL_WRITE_STAGE_RESOURCE_ISOLATION_ENABLED = + buildConf("spark.sql.finalWriteStage.resourceIsolation.enabled") + .doc( + "When true, make final write stage resource isolation using custom RDD resource profile.") + .version("1.8.0") + .booleanConf + .createWithDefault(false) + + val FINAL_WRITE_STAGE_EXECUTOR_CORES = + buildConf("spark.sql.finalWriteStage.executorCores") + .doc("Specify the executor core request for final write stage. " + + "It would be passed to the RDD resource profile.") + .version("1.8.0") + .intConf + .createOptional + + val FINAL_WRITE_STAGE_EXECUTOR_MEMORY = + buildConf("spark.sql.finalWriteStage.executorMemory") + .doc("Specify the executor on heap memory request for final write stage. " + + "It would be passed to the RDD resource profile.") + .version("1.8.0") + .stringConf + .createOptional + + val FINAL_WRITE_STAGE_EXECUTOR_MEMORY_OVERHEAD = + buildConf("spark.sql.finalWriteStage.executorMemoryOverhead") + .doc("Specify the executor memory overhead request for final write stage. " + + "It would be passed to the RDD resource profile.") + .version("1.8.0") + .stringConf + .createOptional + + val FINAL_WRITE_STAGE_EXECUTOR_OFF_HEAP_MEMORY = + buildConf("spark.sql.finalWriteStage.executorOffHeapMemory") + .doc("Specify the executor off heap memory request for final write stage. " + + "It would be passed to the RDD resource profile.") + .version("1.8.0") + .stringConf + .createOptional +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLExtensionException.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLExtensionException.scala new file mode 100644 index 00000000000..88c5a988fd9 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLExtensionException.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql + +import java.sql.SQLException + +class KyuubiSQLExtensionException(reason: String, cause: Throwable) + extends SQLException(reason, cause) { + + def this(reason: String) = { + this(reason, null) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLAstBuilder.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLAstBuilder.scala new file mode 100644 index 00000000000..cc00bf88e94 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLAstBuilder.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql + +import scala.collection.JavaConverters.asScalaBufferConverter +import scala.collection.mutable.ListBuffer + +import org.antlr.v4.runtime.ParserRuleContext +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.tree.ParseTree +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.ParserUtils.withOrigin +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, Sort} + +import org.apache.kyuubi.sql.KyuubiSparkSQLParser._ +import org.apache.kyuubi.sql.zorder.{OptimizeZorderStatement, Zorder} + +class KyuubiSparkSQLAstBuilder extends KyuubiSparkSQLBaseVisitor[AnyRef] with SQLConfHelper { + + def buildOptimizeStatement( + unparsedPredicateOptimize: UnparsedPredicateOptimize, + parseExpression: String => Expression): LogicalPlan = { + + val UnparsedPredicateOptimize(tableIdent, tablePredicate, orderExpr) = + unparsedPredicateOptimize + + val predicate = tablePredicate.map(parseExpression) + verifyPartitionPredicates(predicate) + val table = UnresolvedRelation(tableIdent) + val tableWithFilter = predicate match { + case Some(expr) => Filter(expr, table) + case None => table + } + val query = + Sort( + SortOrder(orderExpr, Ascending, NullsLast, Seq.empty) :: Nil, + conf.getConf(KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED), + Project(Seq(UnresolvedStar(None)), tableWithFilter)) + OptimizeZorderStatement(tableIdent, query) + } + + private def verifyPartitionPredicates(predicates: Option[Expression]): Unit = { + predicates.foreach { + case p if !isLikelySelective(p) => + throw new KyuubiSQLExtensionException(s"unsupported partition predicates: ${p.sql}") + case _ => + } + } + + /** + * Forked from Apache Spark's org.apache.spark.sql.catalyst.expressions.PredicateHelper + * The `PredicateHelper.isLikelySelective()` is available since Spark-3.3, forked for Spark + * 
that is lower than 3.3. + * + * Returns whether an expression is likely to be selective + */ + private def isLikelySelective(e: Expression): Boolean = e match { + case Not(expr) => isLikelySelective(expr) + case And(l, r) => isLikelySelective(l) || isLikelySelective(r) + case Or(l, r) => isLikelySelective(l) && isLikelySelective(r) + case _: StringRegexExpression => true + case _: BinaryComparison => true + case _: In | _: InSet => true + case _: StringPredicate => true + case BinaryPredicate(_) => true + case _: MultiLikeBase => true + case _ => false + } + + private object BinaryPredicate { + def unapply(expr: Expression): Option[Expression] = expr match { + case _: Contains => Option(expr) + case _: StartsWith => Option(expr) + case _: EndsWith => Option(expr) + case _ => None + } + } + + /** + * Create an expression from the given context. This method just passes the context on to the + * visitor and only takes care of typing (We assume that the visitor returns an Expression here). + */ + protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx) + + protected def multiPart(ctx: ParserRuleContext): Seq[String] = typedVisit(ctx) + + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = { + visit(ctx.statement()).asInstanceOf[LogicalPlan] + } + + override def visitOptimizeZorder( + ctx: OptimizeZorderContext): UnparsedPredicateOptimize = withOrigin(ctx) { + val tableIdent = multiPart(ctx.multipartIdentifier()) + + val predicate = Option(ctx.whereClause()) + .map(_.partitionPredicate) + .map(extractRawText(_)) + + val zorderCols = ctx.zorderClause().order.asScala + .map(visitMultipartIdentifier) + .map(UnresolvedAttribute(_)) + .toSeq + + val orderExpr = + if (zorderCols.length == 1) { + zorderCols.head + } else { + Zorder(zorderCols) + } + UnparsedPredicateOptimize(tableIdent, predicate, orderExpr) + } + + override def visitPassThrough(ctx: PassThroughContext): LogicalPlan = null + + override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] = + withOrigin(ctx) { + ctx.parts.asScala.map(_.getText).toSeq + } + + override def visitZorderClause(ctx: ZorderClauseContext): Seq[UnresolvedAttribute] = + withOrigin(ctx) { + val res = ListBuffer[UnresolvedAttribute]() + ctx.multipartIdentifier().forEach { identifier => + res += UnresolvedAttribute(identifier.parts.asScala.map(_.getText).toSeq) + } + res.toSeq + } + + private def typedVisit[T](ctx: ParseTree): T = { + ctx.accept(this).asInstanceOf[T] + } + + private def extractRawText(exprContext: ParserRuleContext): String = { + // Extract the raw expression which will be parsed later + exprContext.getStart.getInputStream.getText(new Interval( + exprContext.getStart.getStartIndex, + exprContext.getStop.getStopIndex)) + } +} + +/** + * a logical plan contains an unparsed expression that will be parsed by spark. 
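The forked `isLikelySelective` heuristic above decides whether an `OPTIMIZE ... WHERE ...` predicate is narrow enough to accept. The toy sketch below uses a hypothetical mini expression type, not Spark's `Expression`, purely to illustrate how NOT/AND/OR combine under the same rules.

```scala
// Toy illustration of the selectivity heuristic: AND is selective if either
// side is, OR only if both sides are. The ADT here is hypothetical.
sealed trait Expr
final case class Not(child: Expr) extends Expr
final case class And(left: Expr, right: Expr) extends Expr
final case class Or(left: Expr, right: Expr) extends Expr
final case class EqualTo(col: String, value: String) extends Expr // stands in for BinaryComparison
final case class In(col: String, values: Seq[String]) extends Expr
final case class IsNotNull(col: String) extends Expr              // not considered selective

object SelectivitySketch {
  def isLikelySelective(e: Expr): Boolean = e match {
    case Not(c)    => isLikelySelective(c)
    case And(l, r) => isLikelySelective(l) || isLikelySelective(r)
    case Or(l, r)  => isLikelySelective(l) && isLikelySelective(r)
    case _: EqualTo | _: In => true
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    val eq = EqualTo("p_date", "2023-10-01")
    println(isLikelySelective(And(eq, IsNotNull("id")))) // true: one side is selective
    println(isLikelySelective(Or(eq, IsNotNull("id"))))  // false: OR needs both sides
  }
}
```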
+ */ +trait UnparsedExpressionLogicalPlan extends LogicalPlan { + override def output: Seq[Attribute] = throw new UnsupportedOperationException() + + override def children: Seq[LogicalPlan] = throw new UnsupportedOperationException() + + protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = + throw new UnsupportedOperationException() +} + +case class UnparsedPredicateOptimize( + tableIdent: Seq[String], + tablePredicate: Option[String], + orderExpr: Expression) extends UnparsedExpressionLogicalPlan {} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLCommonExtension.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLCommonExtension.scala new file mode 100644 index 00000000000..f39ad3cc390 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLCommonExtension.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql + +import org.apache.spark.sql.SparkSessionExtensions + +import org.apache.kyuubi.sql.zorder.{InsertZorderBeforeWritingDatasource33, InsertZorderBeforeWritingHive33, ResolveZorder} + +class KyuubiSparkSQLCommonExtension extends (SparkSessionExtensions => Unit) { + override def apply(extensions: SparkSessionExtensions): Unit = { + KyuubiSparkSQLCommonExtension.injectCommonExtensions(extensions) + } +} + +object KyuubiSparkSQLCommonExtension { + def injectCommonExtensions(extensions: SparkSessionExtensions): Unit = { + // inject zorder parser and related rules + extensions.injectParser { case (_, parser) => new SparkKyuubiSparkSQLParser(parser) } + extensions.injectResolutionRule(ResolveZorder) + + // Note that: + // InsertZorderBeforeWritingDatasource and InsertZorderBeforeWritingHive + // should be applied before + // RepartitionBeforeWriting and RebalanceBeforeWriting + // because we can only apply one of them (i.e. 
Global Sort or Repartition/Rebalance) + extensions.injectPostHocResolutionRule(InsertZorderBeforeWritingDatasource33) + extensions.injectPostHocResolutionRule(InsertZorderBeforeWritingHive33) + extensions.injectPostHocResolutionRule(FinalStageConfigIsolationCleanRule) + + extensions.injectQueryStagePrepRule(_ => InsertShuffleNodeBeforeJoin) + + extensions.injectQueryStagePrepRule(FinalStageConfigIsolation(_)) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala new file mode 100644 index 00000000000..792315d897a --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql + +import org.apache.spark.sql.{FinalStageResourceManager, InjectCustomResourceProfile, SparkSessionExtensions} + +import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MaxScanStrategy} + +// scalastyle:off line.size.limit +/** + * Depend on Spark SQL Extension framework, we can use this extension follow steps + * 1. move this jar into $SPARK_HOME/jars + * 2. add config into `spark-defaults.conf`: `spark.sql.extensions=org.apache.kyuubi.sql.KyuubiSparkSQLExtension` + */ +// scalastyle:on line.size.limit +class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { + override def apply(extensions: SparkSessionExtensions): Unit = { + KyuubiSparkSQLCommonExtension.injectCommonExtensions(extensions) + + extensions.injectPostHocResolutionRule(RebalanceBeforeWritingDatasource) + extensions.injectPostHocResolutionRule(RebalanceBeforeWritingHive) + extensions.injectPostHocResolutionRule(DropIgnoreNonexistent) + + // watchdog extension + extensions.injectOptimizerRule(ForcedMaxOutputRowsRule) + extensions.injectPlannerStrategy(MaxScanStrategy) + + extensions.injectQueryStagePrepRule(FinalStageResourceManager(_)) + extensions.injectQueryStagePrepRule(InjectCustomResourceProfile) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLParser.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLParser.scala new file mode 100644 index 00000000000..c4418c33c44 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLParser.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
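With the parser injected above, the zorder maintenance statement can be issued through plain `spark.sql`. The exact grammar lives in the module's ANTLR file, so the snippet below is a hedged sketch: the table, partition and column names are hypothetical, and `spark` is assumed to be a session started with `spark.sql.extensions=org.apache.kyuubi.sql.KyuubiSparkSQLExtension` (for example in spark-shell).

```scala
// Hedged sketch: issuing the extension's OPTIMIZE ... ZORDER BY statement.
// All identifiers are hypothetical; see the Kyuubi docs for the authoritative grammar.
spark.sql(
  """OPTIMIZE db.events
    |WHERE p_date = '2023-10-01'
    |ZORDER BY user_id, item_id
  """.stripMargin)
// Without a selective partition predicate the AST builder rejects the statement
// with KyuubiSQLExtensionException("unsupported partition predicates: ...").
```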
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql + +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, SQLConfHelper, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException, ParserInterface, PostProcessor} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.types.{DataType, StructType} + +abstract class KyuubiSparkSQLParserBase extends ParserInterface with SQLConfHelper { + def delegate: ParserInterface + def astBuilder: KyuubiSparkSQLAstBuilder + + override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => + astBuilder.visit(parser.singleStatement()) match { + case optimize: UnparsedPredicateOptimize => + astBuilder.buildOptimizeStatement(optimize, delegate.parseExpression) + case plan: LogicalPlan => plan + case _ => delegate.parsePlan(sqlText) + } + } + + protected def parse[T](command: String)(toResult: KyuubiSparkSQLParser => T): T = { + val lexer = new KyuubiSparkSQLLexer( + new UpperCaseCharStream(CharStreams.fromString(command))) + lexer.removeErrorListeners() + lexer.addErrorListener(ParseErrorListener) + + val tokenStream = new CommonTokenStream(lexer) + val parser = new KyuubiSparkSQLParser(tokenStream) + parser.addParseListener(PostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(ParseErrorListener) + + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) + } catch { + case _: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.seek(0) // rewind input stream + parser.reset() + + // Try Again. 
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) + } + } catch { + case e: ParseException if e.command.isDefined => + throw e + case e: ParseException => + throw e.withCommand(command) + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new ParseException(Option(command), e.message, position, position) + } + } + + override def parseExpression(sqlText: String): Expression = { + delegate.parseExpression(sqlText) + } + + override def parseTableIdentifier(sqlText: String): TableIdentifier = { + delegate.parseTableIdentifier(sqlText) + } + + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = { + delegate.parseFunctionIdentifier(sqlText) + } + + override def parseMultipartIdentifier(sqlText: String): Seq[String] = { + delegate.parseMultipartIdentifier(sqlText) + } + + override def parseTableSchema(sqlText: String): StructType = { + delegate.parseTableSchema(sqlText) + } + + override def parseDataType(sqlText: String): DataType = { + delegate.parseDataType(sqlText) + } + + /** + * This functions was introduced since spark-3.3, for more details, please see + * https://github.com/apache/spark/pull/34543 + */ + override def parseQuery(sqlText: String): LogicalPlan = { + delegate.parseQuery(sqlText) + } +} + +class SparkKyuubiSparkSQLParser( + override val delegate: ParserInterface) + extends KyuubiSparkSQLParserBase { + def astBuilder: KyuubiSparkSQLAstBuilder = new KyuubiSparkSQLAstBuilder +} + +/* Copied from Apache Spark's to avoid dependency on Spark Internals */ +class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream { + override def consume(): Unit = wrapped.consume() + override def getSourceName(): String = wrapped.getSourceName + override def index(): Int = wrapped.index + override def mark(): Int = wrapped.mark + override def release(marker: Int): Unit = wrapped.release(marker) + override def seek(where: Int): Unit = wrapped.seek(where) + override def size(): Int = wrapped.size + + override def getText(interval: Interval): String = wrapped.getText(interval) + + // scalastyle:off + override def LA(i: Int): Int = { + val la = wrapped.LA(i) + if (la == 0 || la == IntStream.EOF) la + else Character.toUpperCase(la) + } + // scalastyle:on +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/RebalanceBeforeWriting.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/RebalanceBeforeWriting.scala new file mode 100644 index 00000000000..3cbacdd2f03 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/RebalanceBeforeWriting.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
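The parser above only claims the statements its own grammar recognizes and hands everything else to Spark's parser. A toy version of that delegation pattern, using a hypothetical `MiniParser` trait rather than Spark's `ParserInterface`, is sketched here for intuition.

```scala
// Toy delegation sketch (hypothetical types): the extension parser answers only
// what it understands and defers everything else to the wrapped parser.
trait MiniParser { def parsePlan(sqlText: String): String }

final class FallbackParser(delegate: MiniParser) extends MiniParser {
  private def tryExtension(sqlText: String): Option[String] = {
    val t = sqlText.trim
    if (t.toUpperCase.startsWith("OPTIMIZE")) Some(s"OptimizeZorderStatement(${t.take(40)})")
    else None
  }

  override def parsePlan(sqlText: String): String =
    tryExtension(sqlText).getOrElse(delegate.parsePlan(sqlText))
}

object FallbackParserDemo {
  def main(args: Array[String]): Unit = {
    val sparkLike: MiniParser = (sql: String) => s"SparkPlan($sql)" // stand-in for Spark's parser
    val parser = new FallbackParser(sparkLike)
    println(parser.parsePlan("OPTIMIZE t ZORDER BY a, b")) // handled by the extension
    println(parser.parsePlan("SELECT 1"))                  // delegated to "Spark"
  }
}
```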
+ */ + +package org.apache.kyuubi.sql + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ + +trait RepartitionBuilderWithRebalance extends RepartitionBuilder { + override def buildRepartition( + dynamicPartitionColumns: Seq[Attribute], + query: LogicalPlan): LogicalPlan = { + if (!conf.getConf(KyuubiSQLConf.INFER_REBALANCE_AND_SORT_ORDERS) || + dynamicPartitionColumns.nonEmpty) { + RebalancePartitions(dynamicPartitionColumns, query) + } else { + val maxColumns = conf.getConf(KyuubiSQLConf.INFER_REBALANCE_AND_SORT_ORDERS_MAX_COLUMNS) + val inferred = InferRebalanceAndSortOrders.infer(query) + if (inferred.isDefined) { + val (partitioning, ordering) = inferred.get + val rebalance = RebalancePartitions(partitioning.take(maxColumns), query) + if (ordering.nonEmpty) { + val sortOrders = ordering.take(maxColumns).map(o => SortOrder(o, Ascending)) + Sort(sortOrders, false, rebalance) + } else { + rebalance + } + } else { + RebalancePartitions(dynamicPartitionColumns, query) + } + } + } + + override def canInsertRepartitionByExpression(plan: LogicalPlan): Boolean = { + super.canInsertRepartitionByExpression(plan) && { + plan match { + case _: RebalancePartitions => false + case _ => true + } + } + } +} + +/** + * For datasource table, there two commands can write data to table + * 1. InsertIntoHadoopFsRelationCommand + * 2. CreateDataSourceTableAsSelectCommand + * This rule add a RebalancePartitions node between write and query + */ +case class RebalanceBeforeWritingDatasource(session: SparkSession) + extends RepartitionBeforeWritingDatasourceBase + with RepartitionBuilderWithRebalance {} + +/** + * For Hive table, there two commands can write data to table + * 1. InsertIntoHiveTable + * 2. CreateHiveTableAsSelectCommand + * This rule add a RebalancePartitions node between write and query + */ +case class RebalanceBeforeWritingHive(session: SparkSession) + extends RepartitionBeforeWritingHiveBase + with RepartitionBuilderWithRebalance {} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/RepartitionBeforeWritingBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/RepartitionBeforeWritingBase.scala new file mode 100644 index 00000000000..3ebb9740f5f --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/RepartitionBeforeWritingBase.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
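The rebalance-before-write rules above roughly automate what a user could do by hand with Spark's `REBALANCE` hint on the dynamic partition columns. The snippet below is only an illustration of that intuition: the tables and columns are hypothetical, and it assumes an active session `spark` with AQE enabled (for example in spark-shell).

```scala
// Illustrative only: a manual approximation of RebalanceBeforeWritingHive /
// RebalanceBeforeWritingDatasource. The rules insert RebalancePartitions
// automatically; this hint-based form is just for intuition.
spark.sql(
  """INSERT OVERWRITE TABLE sales PARTITION (p_date)
    |SELECT /*+ REBALANCE(p_date) */ id, amount, p_date
    |FROM staging_sales
  """.stripMargin)
```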
+ */ + +package org.apache.kyuubi.sql + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand +import org.apache.spark.sql.hive.execution.InsertIntoHiveTable +import org.apache.spark.sql.internal.StaticSQLConf + +trait RepartitionBuilder extends Rule[LogicalPlan] with RepartitionBeforeWriteHelper { + def buildRepartition( + dynamicPartitionColumns: Seq[Attribute], + query: LogicalPlan): LogicalPlan +} + +/** + * For datasource table, there two commands can write data to table + * 1. InsertIntoHadoopFsRelationCommand + * 2. CreateDataSourceTableAsSelectCommand + * This rule add a repartition node between write and query + */ +abstract class RepartitionBeforeWritingDatasourceBase extends RepartitionBuilder { + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE)) { + addRepartition(plan) + } else { + plan + } + } + + private def addRepartition(plan: LogicalPlan): LogicalPlan = plan match { + case i @ InsertIntoHadoopFsRelationCommand(_, sp, _, pc, bucket, _, _, query, _, _, _, _) + if query.resolved && bucket.isEmpty && canInsertRepartitionByExpression(query) => + val dynamicPartitionColumns = pc.filterNot(attr => sp.contains(attr.name)) + i.copy(query = buildRepartition(dynamicPartitionColumns, query)) + + case u @ Union(children, _, _) => + u.copy(children = children.map(addRepartition)) + + case _ => plan + } +} + +/** + * For Hive table, there two commands can write data to table + * 1. InsertIntoHiveTable + * 2. CreateHiveTableAsSelectCommand + * This rule add a repartition node between write and query + */ +abstract class RepartitionBeforeWritingHiveBase extends RepartitionBuilder { + override def apply(plan: LogicalPlan): LogicalPlan = { + if (conf.getConf(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive" && + conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE)) { + addRepartition(plan) + } else { + plan + } + } + + def addRepartition(plan: LogicalPlan): LogicalPlan = plan match { + case i @ InsertIntoHiveTable(table, partition, query, _, _, _, _, _, _, _, _) + if query.resolved && table.bucketSpec.isEmpty && canInsertRepartitionByExpression(query) => + val dynamicPartitionColumns = partition.filter(_._2.isEmpty).keys + .flatMap(name => query.output.find(_.name == name)).toSeq + i.copy(query = buildRepartition(dynamicPartitionColumns, query)) + + case u @ Union(children, _, _) => + u.copy(children = children.map(addRepartition)) + + case _ => plan + } +} + +trait RepartitionBeforeWriteHelper extends Rule[LogicalPlan] { + private def hasBenefit(plan: LogicalPlan): Boolean = { + def probablyHasShuffle: Boolean = plan.find { + case _: Join => true + case _: Aggregate => true + case _: Distinct => true + case _: Deduplicate => true + case _: Window => true + case s: Sort if s.global => true + case _: RepartitionOperation => true + case _: GlobalLimit => true + case _ => false + }.isDefined + + conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE) || probablyHasShuffle + } + + def canInsertRepartitionByExpression(plan: LogicalPlan): Boolean = { + def canInsert(p: LogicalPlan): Boolean = p match { + case Project(_, child) => canInsert(child) + case SubqueryAlias(_, child) => canInsert(child) + case Limit(_, _) => false + case _: Sort => false + case _: RepartitionByExpression => false + case _: Repartition 
=> false + case _ => true + } + + // 1. make sure AQE is enabled, otherwise it is no meaning to add a shuffle + // 2. make sure it does not break the semantics of original plan + // 3. try to avoid adding a shuffle if it has potential performance regression + conf.adaptiveExecutionEnabled && canInsert(plan) && hasBenefit(plan) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/WriteUtils.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/WriteUtils.scala new file mode 100644 index 00000000000..89dd8319480 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/WriteUtils.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.{SparkPlan, UnionExec} +import org.apache.spark.sql.execution.command.DataWritingCommandExec +import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec + +object WriteUtils { + def isWrite(session: SparkSession, plan: SparkPlan): Boolean = { + plan match { + case _: DataWritingCommandExec => true + case _: V2TableWriteExec => true + case u: UnionExec if u.children.nonEmpty => u.children.forall(isWrite(session, _)) + case _ => false + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala new file mode 100644 index 00000000000..4f897d1b600 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
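As the helper above spells out, the injected repartition only pays off when AQE can coalesce its output, so the feature is gated on adaptive execution. A minimal sketch of the relevant switches follows; it assumes an active session `spark` (for example in spark-shell) and uses only keys defined in `KyuubiSQLConf` or Spark itself.

```scala
// Preconditions for the repartition-before-write rules to add a shuffle:
spark.conf.set("spark.sql.adaptive.enabled", "true") // without AQE the rule does nothing
spark.conf.set("spark.sql.optimizer.insertRepartitionBeforeWrite.enabled", "true") // default: true
// Opt in to adding the shuffle even when the query has no join/aggregate/etc. of its own:
spark.conf.set("spark.sql.optimizer.insertRepartitionBeforeWriteIfNoShuffle.enabled", "true")
```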
+ */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.command.DataWritingCommand + +import org.apache.kyuubi.sql.KyuubiSQLConf + +/* + * Add ForcedMaxOutputRows rule for output rows limitation + * to avoid huge output rows of non_limit query unexpectedly + * mainly applied to cases as below: + * + * case 1: + * {{{ + * SELECT [c1, c2, ...] + * }}} + * + * case 2: + * {{{ + * WITH CTE AS ( + * ...) + * SELECT [c1, c2, ...] FROM CTE ... + * }}} + * + * The Logical Rule add a GlobalLimit node before root project + * */ +trait ForcedMaxOutputRowsBase extends Rule[LogicalPlan] { + + protected def isChildAggregate(a: Aggregate): Boolean + + protected def canInsertLimitInner(p: LogicalPlan): Boolean = p match { + case Aggregate(_, Alias(_, "havingCondition") :: Nil, _) => false + case agg: Aggregate => !isChildAggregate(agg) + case _: RepartitionByExpression => true + case _: Distinct => true + case _: Filter => true + case _: Project => true + case Limit(_, _) => true + case _: Sort => true + case Union(children, _, _) => + if (children.exists(_.isInstanceOf[DataWritingCommand])) { + false + } else { + true + } + case _: MultiInstanceRelation => true + case _: Join => true + case _ => false + } + + protected def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = { + maxOutputRowsOpt match { + case Some(forcedMaxOutputRows) => canInsertLimitInner(p) && + !p.maxRows.exists(_ <= forcedMaxOutputRows) + case None => false + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + val maxOutputRowsOpt = conf.getConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS) + plan match { + case p if p.resolved && canInsertLimit(p, maxOutputRowsOpt) => + Limit( + maxOutputRowsOpt.get, + plan) + case _ => plan + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala new file mode 100644 index 00000000000..a3d990b1098 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
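The watchdog rule above caps the output of queries that carry no limit of their own. A hedged sketch of the observable behaviour, with a hypothetical table name and an active session `spark` (for example in spark-shell):

```scala
// Illustrative only: effect of the forced max output rows watchdog.
spark.conf.set("spark.sql.watchdog.forcedMaxOutputRows", "1000")
// A query without LIMIT gets a Limit(1000) injected at the root ...
val capped = spark.sql("SELECT * FROM big_table").collect().length            // at most 1000
// ... while a query whose own maxRows is already smaller is left untouched.
val explicit = spark.sql("SELECT * FROM big_table LIMIT 10").collect().length // 10
```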
+ */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, CommandResult, LogicalPlan, Union, WithCTE} +import org.apache.spark.sql.execution.command.DataWritingCommand + +case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase { + + override protected def isChildAggregate(a: Aggregate): Boolean = false + + override protected def canInsertLimitInner(p: LogicalPlan): Boolean = p match { + case WithCTE(plan, _) => this.canInsertLimitInner(plan) + case plan: LogicalPlan => plan match { + case Union(children, _, _) => !children.exists { + case _: DataWritingCommand => true + case p: CommandResult if p.commandLogicalPlan.isInstanceOf[DataWritingCommand] => true + case _ => false + } + case _ => super.canInsertLimitInner(plan) + } + } + + override protected def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = { + p match { + case WithCTE(plan, _) => this.canInsertLimit(plan, maxOutputRowsOpt) + case _ => super.canInsertLimit(p, maxOutputRowsOpt) + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiWatchDogException.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiWatchDogException.scala new file mode 100644 index 00000000000..e44309192a9 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/KyuubiWatchDogException.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.kyuubi.sql.KyuubiSQLExtensionException + +final class MaxPartitionExceedException( + private val reason: String = "", + private val cause: Throwable = None.orNull) + extends KyuubiSQLExtensionException(reason, cause) + +final class MaxFileSizeExceedException( + private val reason: String = "", + private val cause: Throwable = None.orNull) + extends KyuubiSQLExtensionException(reason, cause) diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/MaxScanStrategy.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/MaxScanStrategy.scala new file mode 100644 index 00000000000..1ed55ebc2fd --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/watchdog/MaxScanStrategy.scala @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
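The two exceptions above are thrown when the scan watchdog thresholds are crossed. Below is a self-contained sketch of a session that enables the extension and arms both thresholds; the application name, master and values are placeholders, and only documented config keys are used.

```scala
// Self-contained sketch (all names are placeholders): start a session with the
// Kyuubi extension and set both scan watchdog thresholds read by MaxScanStrategy.
import org.apache.spark.sql.SparkSession

object WatchdogSetupSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]") // local experimentation only
      .appName("kyuubi-watchdog-sketch")
      .config("spark.sql.extensions", "org.apache.kyuubi.sql.KyuubiSparkSQLExtension")
      // fail queries that would scan more than 2000 partitions
      .config("spark.sql.watchdog.maxPartitions", "2000")
      // fail queries that would scan more than ~1 TiB of files
      .config("spark.sql.watchdog.maxFileSize", "1t")
      .getOrCreate()

    spark.sql("SET spark.sql.watchdog.maxPartitions").show(truncate = false)
    spark.stop()
  }
}
```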
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.watchdog + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.{PruneFileSourcePartitionHelper, SparkSession, Strategy} +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation} +import org.apache.spark.sql.catalyst.planning.ScanOperation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, InMemoryFileIndex, LogicalRelation} +import org.apache.spark.sql.types.StructType + +import org.apache.kyuubi.sql.KyuubiSQLConf + +/** + * Add MaxScanStrategy to avoid scan excessive partitions or files + * 1. Check if scan exceed maxPartition of partitioned table + * 2. Check if scan exceed maxFileSize (calculated by hive table and partition statistics) + * This Strategy Add Planner Strategy after LogicalOptimizer + * @param session + */ +case class MaxScanStrategy(session: SparkSession) + extends Strategy + with SQLConfHelper + with PruneFileSourcePartitionHelper { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = { + val maxScanPartitionsOpt = conf.getConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS) + val maxFileSizeOpt = conf.getConf(KyuubiSQLConf.WATCHDOG_MAX_FILE_SIZE) + if (maxScanPartitionsOpt.isDefined || maxFileSizeOpt.isDefined) { + checkScan(plan, maxScanPartitionsOpt, maxFileSizeOpt) + } + Nil + } + + private def checkScan( + plan: LogicalPlan, + maxScanPartitionsOpt: Option[Int], + maxFileSizeOpt: Option[Long]): Unit = { + plan match { + case ScanOperation(_, _, _, relation: HiveTableRelation) => + if (relation.isPartitioned) { + relation.prunedPartitions match { + case Some(prunedPartitions) => + if (maxScanPartitionsOpt.exists(_ < prunedPartitions.size)) { + throw new MaxPartitionExceedException( + s""" + |SQL job scan hive partition: ${prunedPartitions.size} + |exceed restrict of hive scan maxPartition ${maxScanPartitionsOpt.get} + |You should optimize your SQL logical according partition structure + |or shorten query scope such as p_date, detail as below: + |Table: ${relation.tableMeta.qualifiedName} + |Owner: ${relation.tableMeta.owner} + |Partition Structure: ${relation.partitionCols.map(_.name).mkString(", ")} + |""".stripMargin) + } + lazy val scanFileSize = prunedPartitions.flatMap(_.stats).map(_.sizeInBytes).sum + if (maxFileSizeOpt.exists(_ < scanFileSize)) { + throw partTableMaxFileExceedError( + scanFileSize, + maxFileSizeOpt.get, + Some(relation.tableMeta), + prunedPartitions.flatMap(_.storage.locationUri).map(_.toString), + relation.partitionCols.map(_.name)) + } + case _ => + lazy val scanPartitions: Int = session + .sessionState.catalog.externalCatalog.listPartitionNames( + relation.tableMeta.database, + relation.tableMeta.identifier.table).size + if (maxScanPartitionsOpt.exists(_ < scanPartitions)) { + throw new 
MaxPartitionExceedException( + s""" + |Your SQL job scan a whole huge table without any partition filter, + |You should optimize your SQL logical according partition structure + |or shorten query scope such as p_date, detail as below: + |Table: ${relation.tableMeta.qualifiedName} + |Owner: ${relation.tableMeta.owner} + |Partition Structure: ${relation.partitionCols.map(_.name).mkString(", ")} + |""".stripMargin) + } + + lazy val scanFileSize: BigInt = + relation.tableMeta.stats.map(_.sizeInBytes).getOrElse { + session + .sessionState.catalog.externalCatalog.listPartitions( + relation.tableMeta.database, + relation.tableMeta.identifier.table).flatMap(_.stats).map(_.sizeInBytes).sum + } + if (maxFileSizeOpt.exists(_ < scanFileSize)) { + throw new MaxFileSizeExceedException( + s""" + |Your SQL job scan a whole huge table without any partition filter, + |You should optimize your SQL logical according partition structure + |or shorten query scope such as p_date, detail as below: + |Table: ${relation.tableMeta.qualifiedName} + |Owner: ${relation.tableMeta.owner} + |Partition Structure: ${relation.partitionCols.map(_.name).mkString(", ")} + |""".stripMargin) + } + } + } else { + lazy val scanFileSize = relation.tableMeta.stats.map(_.sizeInBytes).sum + if (maxFileSizeOpt.exists(_ < scanFileSize)) { + throw nonPartTableMaxFileExceedError( + scanFileSize, + maxFileSizeOpt.get, + Some(relation.tableMeta)) + } + } + case ScanOperation( + _, + _, + filters, + relation @ LogicalRelation( + fsRelation @ HadoopFsRelation( + fileIndex: InMemoryFileIndex, + partitionSchema, + _, + _, + _, + _), + _, + _, + _)) => + if (fsRelation.partitionSchema.nonEmpty) { + val (partitionKeyFilters, dataFilter) = + getPartitionKeyFiltersAndDataFilters( + SparkSession.active, + relation, + partitionSchema, + filters, + relation.output) + val prunedPartitions = fileIndex.listFiles( + partitionKeyFilters.toSeq, + dataFilter) + if (maxScanPartitionsOpt.exists(_ < prunedPartitions.size)) { + throw maxPartitionExceedError( + prunedPartitions.size, + maxScanPartitionsOpt.get, + relation.catalogTable, + fileIndex.rootPaths, + fsRelation.partitionSchema) + } + lazy val scanFileSize = prunedPartitions.flatMap(_.files).map(_.getLen).sum + if (maxFileSizeOpt.exists(_ < scanFileSize)) { + throw partTableMaxFileExceedError( + scanFileSize, + maxFileSizeOpt.get, + relation.catalogTable, + fileIndex.rootPaths.map(_.toString), + fsRelation.partitionSchema.map(_.name)) + } + } else { + lazy val scanFileSize = fileIndex.sizeInBytes + if (maxFileSizeOpt.exists(_ < scanFileSize)) { + throw nonPartTableMaxFileExceedError( + scanFileSize, + maxFileSizeOpt.get, + relation.catalogTable) + } + } + case ScanOperation( + _, + _, + filters, + logicalRelation @ LogicalRelation( + fsRelation @ HadoopFsRelation( + catalogFileIndex: CatalogFileIndex, + partitionSchema, + _, + _, + _, + _), + _, + _, + _)) => + if (fsRelation.partitionSchema.nonEmpty) { + val (partitionKeyFilters, _) = + getPartitionKeyFiltersAndDataFilters( + SparkSession.active, + logicalRelation, + partitionSchema, + filters, + logicalRelation.output) + + val fileIndex = catalogFileIndex.filterPartitions( + partitionKeyFilters.toSeq) + + lazy val prunedPartitionSize = fileIndex.partitionSpec().partitions.size + if (maxScanPartitionsOpt.exists(_ < prunedPartitionSize)) { + throw maxPartitionExceedError( + prunedPartitionSize, + maxScanPartitionsOpt.get, + logicalRelation.catalogTable, + catalogFileIndex.rootPaths, + fsRelation.partitionSchema) + } + + lazy val scanFileSize = fileIndex + 
.listFiles(Nil, Nil).flatMap(_.files).map(_.getLen).sum + if (maxFileSizeOpt.exists(_ < scanFileSize)) { + throw partTableMaxFileExceedError( + scanFileSize, + maxFileSizeOpt.get, + logicalRelation.catalogTable, + catalogFileIndex.rootPaths.map(_.toString), + fsRelation.partitionSchema.map(_.name)) + } + } else { + lazy val scanFileSize = catalogFileIndex.sizeInBytes + if (maxFileSizeOpt.exists(_ < scanFileSize)) { + throw nonPartTableMaxFileExceedError( + scanFileSize, + maxFileSizeOpt.get, + logicalRelation.catalogTable) + } + } + case _ => + } + } + + def maxPartitionExceedError( + prunedPartitionSize: Int, + maxPartitionSize: Int, + tableMeta: Option[CatalogTable], + rootPaths: Seq[Path], + partitionSchema: StructType): Throwable = { + val truncatedPaths = + if (rootPaths.length > 5) { + rootPaths.slice(0, 5).mkString(",") + """... """ + (rootPaths.length - 5) + " more paths" + } else { + rootPaths.mkString(",") + } + + new MaxPartitionExceedException( + s""" + |SQL job scan data source partition: $prunedPartitionSize + |exceed restrict of data source scan maxPartition $maxPartitionSize + |You should optimize your SQL logical according partition structure + |or shorten query scope such as p_date, detail as below: + |Table: ${tableMeta.map(_.qualifiedName).getOrElse("")} + |Owner: ${tableMeta.map(_.owner).getOrElse("")} + |RootPaths: $truncatedPaths + |Partition Structure: ${partitionSchema.map(_.name).mkString(", ")} + |""".stripMargin) + } + + private def partTableMaxFileExceedError( + scanFileSize: Number, + maxFileSize: Long, + tableMeta: Option[CatalogTable], + rootPaths: Seq[String], + partitions: Seq[String]): Throwable = { + val truncatedPaths = + if (rootPaths.length > 5) { + rootPaths.slice(0, 5).mkString(",") + """... """ + (rootPaths.length - 5) + " more paths" + } else { + rootPaths.mkString(",") + } + + new MaxFileSizeExceedException( + s""" + |SQL job scan file size in bytes: $scanFileSize + |exceed restrict of table scan maxFileSize $maxFileSize + |You should optimize your SQL logical according partition structure + |or shorten query scope such as p_date, detail as below: + |Table: ${tableMeta.map(_.qualifiedName).getOrElse("")} + |Owner: ${tableMeta.map(_.owner).getOrElse("")} + |RootPaths: $truncatedPaths + |Partition Structure: ${partitions.mkString(", ")} + |""".stripMargin) + } + + private def nonPartTableMaxFileExceedError( + scanFileSize: Number, + maxFileSize: Long, + tableMeta: Option[CatalogTable]): Throwable = { + new MaxFileSizeExceedException( + s""" + |SQL job scan file size in bytes: $scanFileSize + |exceed restrict of table scan maxFileSize $maxFileSize + |detail as below: + |Table: ${tableMeta.map(_.qualifiedName).getOrElse("")} + |Owner: ${tableMeta.map(_.owner).getOrElse("")} + |Location: ${tableMeta.map(_.location).getOrElse("")} + |""".stripMargin) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/InsertZorderBeforeWriting.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/InsertZorderBeforeWriting.scala new file mode 100644 index 00000000000..b3f98ec6d7f --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/InsertZorderBeforeWriting.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
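From the caller's perspective, exceeding the thresholds enforced above surfaces as one of the watchdog exceptions. A hedged sketch with a hypothetical table name, assuming a session like the one in the setup sketch earlier:

```scala
// Illustrative only: what hitting the partition watchdog looks like to a client.
import org.apache.kyuubi.sql.watchdog.MaxPartitionExceedException

spark.conf.set("spark.sql.watchdog.maxPartitions", "100")
try {
  spark.sql("SELECT * FROM big_partitioned_table").collect()
} catch {
  case e: MaxPartitionExceedException =>
    // The message names the table, owner and partition structure so the user
    // can add a partition filter (e.g. a p_date predicate) and retry.
    println(e.getMessage)
}
```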
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.zorder + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, NullsLast, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand +import org.apache.spark.sql.hive.execution.InsertIntoHiveTable + +import org.apache.kyuubi.sql.{KyuubiSQLConf, KyuubiSQLExtensionException} + +trait InsertZorderHelper33 extends Rule[LogicalPlan] with ZorderBuilder { + private val KYUUBI_ZORDER_ENABLED = "kyuubi.zorder.enabled" + private val KYUUBI_ZORDER_COLS = "kyuubi.zorder.cols" + + def isZorderEnabled(props: Map[String, String]): Boolean = { + props.contains(KYUUBI_ZORDER_ENABLED) && + "true".equalsIgnoreCase(props(KYUUBI_ZORDER_ENABLED)) && + props.contains(KYUUBI_ZORDER_COLS) + } + + def getZorderColumns(props: Map[String, String]): Seq[String] = { + val cols = props.get(KYUUBI_ZORDER_COLS) + assert(cols.isDefined) + cols.get.split(",").map(_.trim) + } + + def canInsertZorder(query: LogicalPlan): Boolean = query match { + case Project(_, child) => canInsertZorder(child) + // TODO: actually, we can force zorder even if existed some shuffle + case _: Sort => false + case _: RepartitionByExpression => false + case _: Repartition => false + case _ => true + } + + def insertZorder( + catalogTable: CatalogTable, + plan: LogicalPlan, + dynamicPartitionColumns: Seq[Attribute]): LogicalPlan = { + if (!canInsertZorder(plan)) { + return plan + } + val cols = getZorderColumns(catalogTable.properties) + val resolver = session.sessionState.conf.resolver + val output = plan.output + val bound = cols.flatMap(col => output.find(attr => resolver(attr.name, col))) + if (bound.size < cols.size) { + logWarning(s"target table does not contain all zorder cols: ${cols.mkString(",")}, " + + s"please check your table properties ${KYUUBI_ZORDER_COLS}.") + plan + } else { + if (conf.getConf(KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED) && + conf.getConf(KyuubiSQLConf.REBALANCE_BEFORE_ZORDER)) { + throw new KyuubiSQLExtensionException(s"${KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED.key} " + + s"and ${KyuubiSQLConf.REBALANCE_BEFORE_ZORDER.key} can not be enabled together.") + } + if (conf.getConf(KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED) && + dynamicPartitionColumns.nonEmpty) { + logWarning(s"Dynamic partition insertion with global sort may produce small files.") + } + + val zorderExpr = + if (bound.length == 1) { + bound + } else if (conf.getConf(KyuubiSQLConf.ZORDER_USING_ORIGINAL_ORDERING_ENABLED)) { + bound.asInstanceOf[Seq[Expression]] + } else { + buildZorder(bound) :: Nil + } + val (global, orderExprs, child) = + if (conf.getConf(KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED)) { + (true, zorderExpr, plan) + } else if 
(conf.getConf(KyuubiSQLConf.REBALANCE_BEFORE_ZORDER)) { + val rebalanceExpr = + if (dynamicPartitionColumns.isEmpty) { + // static partition insert + bound + } else if (conf.getConf(KyuubiSQLConf.REBALANCE_ZORDER_COLUMNS_ENABLED)) { + // improve data compression ratio + dynamicPartitionColumns.asInstanceOf[Seq[Expression]] ++ bound + } else { + dynamicPartitionColumns.asInstanceOf[Seq[Expression]] + } + // for dynamic partition insert, Spark always sort the partition columns, + // so here we sort partition columns + zorder. + val rebalance = + if (dynamicPartitionColumns.nonEmpty && + conf.getConf(KyuubiSQLConf.TWO_PHASE_REBALANCE_BEFORE_ZORDER)) { + // improve compression ratio + RebalancePartitions( + rebalanceExpr, + RebalancePartitions(dynamicPartitionColumns, plan)) + } else { + RebalancePartitions(rebalanceExpr, plan) + } + (false, dynamicPartitionColumns.asInstanceOf[Seq[Expression]] ++ zorderExpr, rebalance) + } else { + (false, zorderExpr, plan) + } + val order = orderExprs.map { expr => + SortOrder(expr, Ascending, NullsLast, Seq.empty) + } + Sort(order, global, child) + } + } + + override def buildZorder(children: Seq[Expression]): ZorderBase = Zorder(children) + + def session: SparkSession + def applyInternal(plan: LogicalPlan): LogicalPlan + + final override def apply(plan: LogicalPlan): LogicalPlan = { + if (conf.getConf(KyuubiSQLConf.INSERT_ZORDER_BEFORE_WRITING)) { + applyInternal(plan) + } else { + plan + } + } +} + +case class InsertZorderBeforeWritingDatasource33(session: SparkSession) + extends InsertZorderHelper33 { + override def applyInternal(plan: LogicalPlan): LogicalPlan = plan match { + case insert: InsertIntoHadoopFsRelationCommand + if insert.query.resolved && + insert.bucketSpec.isEmpty && insert.catalogTable.isDefined && + isZorderEnabled(insert.catalogTable.get.properties) => + val dynamicPartition = + insert.partitionColumns.filterNot(attr => insert.staticPartitions.contains(attr.name)) + val newQuery = insertZorder(insert.catalogTable.get, insert.query, dynamicPartition) + if (newQuery.eq(insert.query)) { + insert + } else { + insert.copy(query = newQuery) + } + + case _ => plan + } +} + +case class InsertZorderBeforeWritingHive33(session: SparkSession) + extends InsertZorderHelper33 { + override def applyInternal(plan: LogicalPlan): LogicalPlan = plan match { + case insert: InsertIntoHiveTable + if insert.query.resolved && + insert.table.bucketSpec.isEmpty && isZorderEnabled(insert.table.properties) => + val dynamicPartition = insert.partition.filter(_._2.isEmpty).keys + .flatMap(name => insert.query.output.find(_.name == name)).toSeq + val newQuery = insertZorder(insert.table, insert.query, dynamicPartition) + if (newQuery.eq(insert.query)) { + insert + } else { + insert.copy(query = newQuery) + } + + case _ => plan + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/InsertZorderBeforeWritingBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/InsertZorderBeforeWritingBase.scala new file mode 100644 index 00000000000..2c59d148e98 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/InsertZorderBeforeWritingBase.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.zorder + +import java.util.Locale + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, NullsLast, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand +import org.apache.spark.sql.hive.execution.InsertIntoHiveTable + +import org.apache.kyuubi.sql.KyuubiSQLConf + +/** + * TODO: shall we forbid zorder if it's dynamic partition inserts ? + * Insert zorder before writing datasource if the target table properties has zorder properties + */ +abstract class InsertZorderBeforeWritingDatasourceBase + extends InsertZorderHelper { + override def applyInternal(plan: LogicalPlan): LogicalPlan = plan match { + case insert: InsertIntoHadoopFsRelationCommand + if insert.query.resolved && insert.bucketSpec.isEmpty && insert.catalogTable.isDefined && + isZorderEnabled(insert.catalogTable.get.properties) => + val newQuery = insertZorder(insert.catalogTable.get, insert.query) + if (newQuery.eq(insert.query)) { + insert + } else { + insert.copy(query = newQuery) + } + case _ => plan + } +} + +/** + * TODO: shall we forbid zorder if it's dynamic partition inserts ? 
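+ * (Illustrative usage, not part of the patch: the rules in this file only fire for tables
+ * whose properties contain the two keys checked in [[InsertZorderHelper]], e.g.
+ *   ALTER TABLE db.tbl SET TBLPROPERTIES (
+ *     'kyuubi.zorder.enabled' = 'true',
+ *     'kyuubi.zorder.cols' = 'c1,c2');
+ * subsequent INSERTs into db.tbl then get a z-order sort injected before the write.)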
+ * Insert zorder before writing hive if the target table properties has zorder properties + */ +abstract class InsertZorderBeforeWritingHiveBase + extends InsertZorderHelper { + override def applyInternal(plan: LogicalPlan): LogicalPlan = plan match { + case insert: InsertIntoHiveTable + if insert.query.resolved && insert.table.bucketSpec.isEmpty && + isZorderEnabled(insert.table.properties) => + val newQuery = insertZorder(insert.table, insert.query) + if (newQuery.eq(insert.query)) { + insert + } else { + insert.copy(query = newQuery) + } + case _ => plan + } +} + +trait ZorderBuilder { + def buildZorder(children: Seq[Expression]): ZorderBase +} + +trait InsertZorderHelper extends Rule[LogicalPlan] with ZorderBuilder { + private val KYUUBI_ZORDER_ENABLED = "kyuubi.zorder.enabled" + private val KYUUBI_ZORDER_COLS = "kyuubi.zorder.cols" + + def isZorderEnabled(props: Map[String, String]): Boolean = { + props.contains(KYUUBI_ZORDER_ENABLED) && + "true".equalsIgnoreCase(props(KYUUBI_ZORDER_ENABLED)) && + props.contains(KYUUBI_ZORDER_COLS) + } + + def getZorderColumns(props: Map[String, String]): Seq[String] = { + val cols = props.get(KYUUBI_ZORDER_COLS) + assert(cols.isDefined) + cols.get.split(",").map(_.trim.toLowerCase(Locale.ROOT)) + } + + def canInsertZorder(query: LogicalPlan): Boolean = query match { + case Project(_, child) => canInsertZorder(child) + // TODO: actually, we can force zorder even if existed some shuffle + case _: Sort => false + case _: RepartitionByExpression => false + case _: Repartition => false + case _ => true + } + + def insertZorder(catalogTable: CatalogTable, plan: LogicalPlan): LogicalPlan = { + if (!canInsertZorder(plan)) { + return plan + } + val cols = getZorderColumns(catalogTable.properties) + val attrs = plan.output.map(attr => (attr.name, attr)).toMap + if (cols.exists(!attrs.contains(_))) { + logWarning(s"target table does not contain all zorder cols: ${cols.mkString(",")}, " + + s"please check your table properties ${KYUUBI_ZORDER_COLS}.") + plan + } else { + val bound = cols.map(attrs(_)) + val orderExpr = + if (bound.length == 1) { + bound.head + } else { + buildZorder(bound) + } + // TODO: We can do rebalance partitions before local sort of zorder after SPARK 3.3 + // see https://github.com/apache/spark/pull/34542 + Sort( + SortOrder(orderExpr, Ascending, NullsLast, Seq.empty) :: Nil, + conf.getConf(KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED), + plan) + } + } + + def applyInternal(plan: LogicalPlan): LogicalPlan + + final override def apply(plan: LogicalPlan): LogicalPlan = { + if (conf.getConf(KyuubiSQLConf.INSERT_ZORDER_BEFORE_WRITING)) { + applyInternal(plan) + } else { + plan + } + } +} + +/** + * TODO: shall we forbid zorder if it's dynamic partition inserts ? + * Insert zorder before writing datasource if the target table properties has zorder properties + */ +case class InsertZorderBeforeWritingDatasource(session: SparkSession) + extends InsertZorderBeforeWritingDatasourceBase { + override def buildZorder(children: Seq[Expression]): ZorderBase = Zorder(children) +} + +/** + * TODO: shall we forbid zorder if it's dynamic partition inserts ? 
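+ * (The concrete case classes in this file only plug in the [[Zorder]] expression builder;
+ * the pattern matching on InsertIntoHadoopFsRelationCommand vs. InsertIntoHiveTable lives in
+ * the abstract bases above.)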
+ * Insert zorder before writing hive if the target table properties has zorder properties + */ +case class InsertZorderBeforeWritingHive(session: SparkSession) + extends InsertZorderBeforeWritingHiveBase { + override def buildZorder(children: Seq[Expression]): ZorderBase = Zorder(children) +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/OptimizeZorderCommandBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/OptimizeZorderCommandBase.scala new file mode 100644 index 00000000000..21d1cf2a25b --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/OptimizeZorderCommandBase.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.zorder + +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.DataWritingCommand +import org.apache.spark.sql.hive.execution.InsertIntoHiveTable + +import org.apache.kyuubi.sql.KyuubiSQLExtensionException + +/** + * A runnable command for zorder, we delegate to real command to execute + */ +abstract class OptimizeZorderCommandBase extends DataWritingCommand { + def catalogTable: CatalogTable + + override def outputColumnNames: Seq[String] = query.output.map(_.name) + + private def isHiveTable: Boolean = { + catalogTable.provider.isEmpty || + (catalogTable.provider.isDefined && "hive".equalsIgnoreCase(catalogTable.provider.get)) + } + + private def getWritingCommand(session: SparkSession): DataWritingCommand = { + // TODO: Support convert hive relation to datasource relation, can see + // [[org.apache.spark.sql.hive.RelationConversions]] + InsertIntoHiveTable( + catalogTable, + catalogTable.partitionColumnNames.map(p => (p, None)).toMap, + query, + overwrite = true, + ifPartitionNotExists = false, + outputColumnNames) + } + + override def run(session: SparkSession, child: SparkPlan): Seq[Row] = { + // TODO: Support datasource relation + // TODO: Support read and insert overwrite the same table for some table format + if (!isHiveTable) { + throw new KyuubiSQLExtensionException("only support hive table") + } + + val command = getWritingCommand(session) + command.run(session, child) + DataWritingCommand.propogateMetrics(session.sparkContext, command, metrics) + Seq.empty + } +} + +/** + * A runnable command for zorder, we delegate to real command to execute + */ +case class OptimizeZorderCommand( + catalogTable: CatalogTable, + query: LogicalPlan) + extends OptimizeZorderCommandBase { + protected def 
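+  // TreeNode copy hook: rebind the transformed child as the command's new `query`.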
withNewChildInternal(newChild: LogicalPlan): LogicalPlan = { + copy(query = newChild) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/OptimizeZorderStatementBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/OptimizeZorderStatementBase.scala new file mode 100644 index 00000000000..895f9e24be3 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/OptimizeZorderStatementBase.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.zorder + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} + +/** + * A zorder statement that holds the query we parsed from SQL. + * It should be converted to a concrete command by the Analyzer. + */ +case class OptimizeZorderStatement( + tableIdentifier: Seq[String], + query: LogicalPlan) extends UnaryNode { + override def child: LogicalPlan = query + override def output: Seq[Attribute] = child.output + protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = + copy(query = newChild) +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ResolveZorderBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ResolveZorderBase.scala new file mode 100644 index 00000000000..9f735caa7a7 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ResolveZorderBase.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.sql.zorder + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation} +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.catalyst.rules.Rule + +import org.apache.kyuubi.sql.KyuubiSQLExtensionException + +/** + * Resolve `OptimizeZorderStatement` to `OptimizeZorderCommand` + */ +abstract class ResolveZorderBase extends Rule[LogicalPlan] { + def session: SparkSession + def buildOptimizeZorderCommand( + catalogTable: CatalogTable, + query: LogicalPlan): OptimizeZorderCommandBase + + protected def checkQueryAllowed(query: LogicalPlan): Unit = query foreach { + case Filter(condition, SubqueryAlias(_, tableRelation: HiveTableRelation)) => + if (tableRelation.partitionCols.isEmpty) { + throw new KyuubiSQLExtensionException("Filters are only supported for partitioned table") + } + + val partitionKeyIds = AttributeSet(tableRelation.partitionCols) + if (condition.references.isEmpty || !condition.references.subsetOf(partitionKeyIds)) { + throw new KyuubiSQLExtensionException("Only partition column filters are allowed") + } + + case _ => + } + + protected def getTableIdentifier(tableIdent: Seq[String]): TableIdentifier = tableIdent match { + case Seq(tbl) => TableIdentifier.apply(tbl) + case Seq(db, tbl) => TableIdentifier.apply(tbl, Some(db)) + case _ => throw new KyuubiSQLExtensionException( + "only support session catalog table, please use db.table instead") + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan match { + case statement: OptimizeZorderStatement if statement.query.resolved => + checkQueryAllowed(statement.query) + val tableIdentifier = getTableIdentifier(statement.tableIdentifier) + val catalogTable = session.sessionState.catalog.getTableMetadata(tableIdentifier) + buildOptimizeZorderCommand(catalogTable, statement.query) + + case _ => plan + } +} + +/** + * Resolve `OptimizeZorderStatement` to `OptimizeZorderCommand` + */ +case class ResolveZorder(session: SparkSession) extends ResolveZorderBase { + override def buildOptimizeZorderCommand( + catalogTable: CatalogTable, + query: LogicalPlan): OptimizeZorderCommandBase = { + OptimizeZorderCommand(catalogTable, query) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ZorderBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ZorderBase.scala new file mode 100644 index 00000000000..e4d98ccbe84 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ZorderBase.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.zorder + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.types.{BinaryType, DataType} + +import org.apache.kyuubi.sql.KyuubiSQLExtensionException + +abstract class ZorderBase extends Expression { + override def foldable: Boolean = children.forall(_.foldable) + override def nullable: Boolean = false + override def dataType: DataType = BinaryType + override def prettyName: String = "zorder" + + override def checkInputDataTypes(): TypeCheckResult = { + try { + defaultNullValues + TypeCheckResult.TypeCheckSuccess + } catch { + case e: KyuubiSQLExtensionException => + TypeCheckResult.TypeCheckFailure(e.getMessage) + } + } + + @transient + private[this] lazy val defaultNullValues: Array[Any] = + children.map(_.dataType) + .map(ZorderBytesUtils.defaultValue) + .toArray + + override def eval(input: InternalRow): Any = { + val childrenValues = children.zipWithIndex.map { + case (child: Expression, index) => + val v = child.eval(input) + if (v == null) { + defaultNullValues(index) + } else { + v + } + } + ZorderBytesUtils.interleaveBits(childrenValues.toArray) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val defaultValues = ctx.addReferenceObj("defaultValues", defaultNullValues) + val values = ctx.freshName("values") + val util = ZorderBytesUtils.getClass.getName.stripSuffix("$") + val inputs = evals.zipWithIndex.map { + case (eval, index) => + s""" + |${eval.code} + |if (${eval.isNull}) { + | $values[$index] = $defaultValues[$index]; + |} else { + | $values[$index] = ${eval.value}; + |} + |""".stripMargin + } + ev.copy( + code = + code""" + |byte[] ${ev.value} = null; + |Object[] $values = new Object[${evals.length}]; + |${inputs.mkString("\n")} + |${ev.value} = $util.interleaveBits($values); + |""".stripMargin, + isNull = FalseLiteral) + } +} + +case class Zorder(children: Seq[Expression]) extends ZorderBase { + protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(children = newChildren) +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ZorderBytesUtils.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ZorderBytesUtils.scala new file mode 100644 index 00000000000..d249f1dc32f --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/kyuubi/sql/zorder/ZorderBytesUtils.scala @@ -0,0 +1,517 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.sql.zorder + +import java.lang.{Double => jDouble, Float => jFloat} + +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +import org.apache.kyuubi.sql.KyuubiSQLExtensionException + +object ZorderBytesUtils { + final private val BIT_8_MASK = 1 << 7 + final private val BIT_16_MASK = 1 << 15 + final private val BIT_32_MASK = 1 << 31 + final private val BIT_64_MASK = 1L << 63 + + def interleaveBits(inputs: Array[Any]): Array[Byte] = { + inputs.length match { + // it's a more fast approach, use O(8 * 8) + // can see http://graphics.stanford.edu/~seander/bithacks.html#InterleaveTableObvious + case 1 => longToByte(toLong(inputs(0))) + case 2 => interleave2Longs(toLong(inputs(0)), toLong(inputs(1))) + case 3 => interleave3Longs(toLong(inputs(0)), toLong(inputs(1)), toLong(inputs(2))) + case 4 => + interleave4Longs(toLong(inputs(0)), toLong(inputs(1)), toLong(inputs(2)), toLong(inputs(3))) + case 5 => interleave5Longs( + toLong(inputs(0)), + toLong(inputs(1)), + toLong(inputs(2)), + toLong(inputs(3)), + toLong(inputs(4))) + case 6 => interleave6Longs( + toLong(inputs(0)), + toLong(inputs(1)), + toLong(inputs(2)), + toLong(inputs(3)), + toLong(inputs(4)), + toLong(inputs(5))) + case 7 => interleave7Longs( + toLong(inputs(0)), + toLong(inputs(1)), + toLong(inputs(2)), + toLong(inputs(3)), + toLong(inputs(4)), + toLong(inputs(5)), + toLong(inputs(6))) + case 8 => interleave8Longs( + toLong(inputs(0)), + toLong(inputs(1)), + toLong(inputs(2)), + toLong(inputs(3)), + toLong(inputs(4)), + toLong(inputs(5)), + toLong(inputs(6)), + toLong(inputs(7))) + + case _ => + // it's the default approach, use O(64 * n), n is the length of inputs + interleaveBitsDefault(inputs.map(toByteArray)) + } + } + + private def interleave2Longs(l1: Long, l2: Long): Array[Byte] = { + // output 8 * 16 bits + val result = new Array[Byte](16) + var i = 0 + while (i < 8) { + val tmp1 = ((l1 >> (i * 8)) & 0xFF).toShort + val tmp2 = ((l2 >> (i * 8)) & 0xFF).toShort + + var z = 0 + var j = 0 + while (j < 8) { + val x_masked = tmp1 & (1 << j) + val y_masked = tmp2 & (1 << j) + z |= (x_masked << j) + z |= (y_masked << (j + 1)) + j = j + 1 + } + result((7 - i) * 2 + 1) = (z & 0xFF).toByte + result((7 - i) * 2) = ((z >> 8) & 0xFF).toByte + i = i + 1 + } + result + } + + private def interleave3Longs(l1: Long, l2: Long, l3: Long): Array[Byte] = { + // output 8 * 24 bits + val result = new Array[Byte](24) + var i = 0 + while (i < 8) { + val tmp1 = ((l1 >> (i * 8)) & 0xFF).toInt + val tmp2 = ((l2 >> (i * 8)) & 0xFF).toInt + val tmp3 = ((l3 >> (i * 8)) & 0xFF).toInt + + var z = 0 + var j = 0 + while (j < 8) { + val r1_mask = tmp1 & (1 << j) + val r2_mask = tmp2 & (1 << j) + val r3_mask = tmp3 & (1 << j) + z |= (r1_mask << (2 * j)) | (r2_mask << (2 * j + 1)) | (r3_mask << (2 * j + 2)) + j = j + 1 + } + result((7 - i) * 3 + 2) = (z & 0xFF).toByte + result((7 - i) * 3 + 1) = ((z >> 8) & 0xFF).toByte + result((7 - i) * 3) = ((z >> 16) & 0xFF).toByte + i = i + 1 + } + result + } + + private def interleave4Longs(l1: Long, l2: Long, l3: Long, l4: Long): Array[Byte] = 
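+  // All interleaveNLongs helpers share one pattern: for each of the 8 input bytes, bit j of
+  // input k is scattered to output bit N * j + (k - 1), and the interleaved bytes are written
+  // out big-endian. For example, two low bytes x = 0b11 and y = 0b00 interleave to 0b0101:
+  // x occupies the even bit positions, y the odd ones.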
{ + // output 8 * 32 bits + val result = new Array[Byte](32) + var i = 0 + while (i < 8) { + val tmp1 = ((l1 >> (i * 8)) & 0xFF).toInt + val tmp2 = ((l2 >> (i * 8)) & 0xFF).toInt + val tmp3 = ((l3 >> (i * 8)) & 0xFF).toInt + val tmp4 = ((l4 >> (i * 8)) & 0xFF).toInt + + var z = 0 + var j = 0 + while (j < 8) { + val r1_mask = tmp1 & (1 << j) + val r2_mask = tmp2 & (1 << j) + val r3_mask = tmp3 & (1 << j) + val r4_mask = tmp4 & (1 << j) + z |= (r1_mask << (3 * j)) | (r2_mask << (3 * j + 1)) | (r3_mask << (3 * j + 2)) | + (r4_mask << (3 * j + 3)) + j = j + 1 + } + result((7 - i) * 4 + 3) = (z & 0xFF).toByte + result((7 - i) * 4 + 2) = ((z >> 8) & 0xFF).toByte + result((7 - i) * 4 + 1) = ((z >> 16) & 0xFF).toByte + result((7 - i) * 4) = ((z >> 24) & 0xFF).toByte + i = i + 1 + } + result + } + + private def interleave5Longs( + l1: Long, + l2: Long, + l3: Long, + l4: Long, + l5: Long): Array[Byte] = { + // output 8 * 40 bits + val result = new Array[Byte](40) + var i = 0 + while (i < 8) { + val tmp1 = ((l1 >> (i * 8)) & 0xFF).toLong + val tmp2 = ((l2 >> (i * 8)) & 0xFF).toLong + val tmp3 = ((l3 >> (i * 8)) & 0xFF).toLong + val tmp4 = ((l4 >> (i * 8)) & 0xFF).toLong + val tmp5 = ((l5 >> (i * 8)) & 0xFF).toLong + + var z = 0L + var j = 0 + while (j < 8) { + val r1_mask = tmp1 & (1 << j) + val r2_mask = tmp2 & (1 << j) + val r3_mask = tmp3 & (1 << j) + val r4_mask = tmp4 & (1 << j) + val r5_mask = tmp5 & (1 << j) + z |= (r1_mask << (4 * j)) | (r2_mask << (4 * j + 1)) | (r3_mask << (4 * j + 2)) | + (r4_mask << (4 * j + 3)) | (r5_mask << (4 * j + 4)) + j = j + 1 + } + result((7 - i) * 5 + 4) = (z & 0xFF).toByte + result((7 - i) * 5 + 3) = ((z >> 8) & 0xFF).toByte + result((7 - i) * 5 + 2) = ((z >> 16) & 0xFF).toByte + result((7 - i) * 5 + 1) = ((z >> 24) & 0xFF).toByte + result((7 - i) * 5) = ((z >> 32) & 0xFF).toByte + i = i + 1 + } + result + } + + private def interleave6Longs( + l1: Long, + l2: Long, + l3: Long, + l4: Long, + l5: Long, + l6: Long): Array[Byte] = { + // output 8 * 48 bits + val result = new Array[Byte](48) + var i = 0 + while (i < 8) { + val tmp1 = ((l1 >> (i * 8)) & 0xFF).toLong + val tmp2 = ((l2 >> (i * 8)) & 0xFF).toLong + val tmp3 = ((l3 >> (i * 8)) & 0xFF).toLong + val tmp4 = ((l4 >> (i * 8)) & 0xFF).toLong + val tmp5 = ((l5 >> (i * 8)) & 0xFF).toLong + val tmp6 = ((l6 >> (i * 8)) & 0xFF).toLong + + var z = 0L + var j = 0 + while (j < 8) { + val r1_mask = tmp1 & (1 << j) + val r2_mask = tmp2 & (1 << j) + val r3_mask = tmp3 & (1 << j) + val r4_mask = tmp4 & (1 << j) + val r5_mask = tmp5 & (1 << j) + val r6_mask = tmp6 & (1 << j) + z |= (r1_mask << (5 * j)) | (r2_mask << (5 * j + 1)) | (r3_mask << (5 * j + 2)) | + (r4_mask << (5 * j + 3)) | (r5_mask << (5 * j + 4)) | (r6_mask << (5 * j + 5)) + j = j + 1 + } + result((7 - i) * 6 + 5) = (z & 0xFF).toByte + result((7 - i) * 6 + 4) = ((z >> 8) & 0xFF).toByte + result((7 - i) * 6 + 3) = ((z >> 16) & 0xFF).toByte + result((7 - i) * 6 + 2) = ((z >> 24) & 0xFF).toByte + result((7 - i) * 6 + 1) = ((z >> 32) & 0xFF).toByte + result((7 - i) * 6) = ((z >> 40) & 0xFF).toByte + i = i + 1 + } + result + } + + private def interleave7Longs( + l1: Long, + l2: Long, + l3: Long, + l4: Long, + l5: Long, + l6: Long, + l7: Long): Array[Byte] = { + // output 8 * 56 bits + val result = new Array[Byte](56) + var i = 0 + while (i < 8) { + val tmp1 = ((l1 >> (i * 8)) & 0xFF).toLong + val tmp2 = ((l2 >> (i * 8)) & 0xFF).toLong + val tmp3 = ((l3 >> (i * 8)) & 0xFF).toLong + val tmp4 = ((l4 >> (i * 8)) & 0xFF).toLong + val tmp5 = ((l5 >> (i * 8)) & 
0xFF).toLong + val tmp6 = ((l6 >> (i * 8)) & 0xFF).toLong + val tmp7 = ((l7 >> (i * 8)) & 0xFF).toLong + + var z = 0L + var j = 0 + while (j < 8) { + val r1_mask = tmp1 & (1 << j) + val r2_mask = tmp2 & (1 << j) + val r3_mask = tmp3 & (1 << j) + val r4_mask = tmp4 & (1 << j) + val r5_mask = tmp5 & (1 << j) + val r6_mask = tmp6 & (1 << j) + val r7_mask = tmp7 & (1 << j) + z |= (r1_mask << (6 * j)) | (r2_mask << (6 * j + 1)) | (r3_mask << (6 * j + 2)) | + (r4_mask << (6 * j + 3)) | (r5_mask << (6 * j + 4)) | (r6_mask << (6 * j + 5)) | + (r7_mask << (6 * j + 6)) + j = j + 1 + } + result((7 - i) * 7 + 6) = (z & 0xFF).toByte + result((7 - i) * 7 + 5) = ((z >> 8) & 0xFF).toByte + result((7 - i) * 7 + 4) = ((z >> 16) & 0xFF).toByte + result((7 - i) * 7 + 3) = ((z >> 24) & 0xFF).toByte + result((7 - i) * 7 + 2) = ((z >> 32) & 0xFF).toByte + result((7 - i) * 7 + 1) = ((z >> 40) & 0xFF).toByte + result((7 - i) * 7) = ((z >> 48) & 0xFF).toByte + i = i + 1 + } + result + } + + private def interleave8Longs( + l1: Long, + l2: Long, + l3: Long, + l4: Long, + l5: Long, + l6: Long, + l7: Long, + l8: Long): Array[Byte] = { + // output 8 * 64 bits + val result = new Array[Byte](64) + var i = 0 + while (i < 8) { + val tmp1 = ((l1 >> (i * 8)) & 0xFF).toLong + val tmp2 = ((l2 >> (i * 8)) & 0xFF).toLong + val tmp3 = ((l3 >> (i * 8)) & 0xFF).toLong + val tmp4 = ((l4 >> (i * 8)) & 0xFF).toLong + val tmp5 = ((l5 >> (i * 8)) & 0xFF).toLong + val tmp6 = ((l6 >> (i * 8)) & 0xFF).toLong + val tmp7 = ((l7 >> (i * 8)) & 0xFF).toLong + val tmp8 = ((l8 >> (i * 8)) & 0xFF).toLong + + var z = 0L + var j = 0 + while (j < 8) { + val r1_mask = tmp1 & (1 << j) + val r2_mask = tmp2 & (1 << j) + val r3_mask = tmp3 & (1 << j) + val r4_mask = tmp4 & (1 << j) + val r5_mask = tmp5 & (1 << j) + val r6_mask = tmp6 & (1 << j) + val r7_mask = tmp7 & (1 << j) + val r8_mask = tmp8 & (1 << j) + z |= (r1_mask << (7 * j)) | (r2_mask << (7 * j + 1)) | (r3_mask << (7 * j + 2)) | + (r4_mask << (7 * j + 3)) | (r5_mask << (7 * j + 4)) | (r6_mask << (7 * j + 5)) | + (r7_mask << (7 * j + 6)) | (r8_mask << (7 * j + 7)) + j = j + 1 + } + result((7 - i) * 8 + 7) = (z & 0xFF).toByte + result((7 - i) * 8 + 6) = ((z >> 8) & 0xFF).toByte + result((7 - i) * 8 + 5) = ((z >> 16) & 0xFF).toByte + result((7 - i) * 8 + 4) = ((z >> 24) & 0xFF).toByte + result((7 - i) * 8 + 3) = ((z >> 32) & 0xFF).toByte + result((7 - i) * 8 + 2) = ((z >> 40) & 0xFF).toByte + result((7 - i) * 8 + 1) = ((z >> 48) & 0xFF).toByte + result((7 - i) * 8) = ((z >> 56) & 0xFF).toByte + i = i + 1 + } + result + } + + def interleaveBitsDefault(arrays: Array[Array[Byte]]): Array[Byte] = { + var totalLength = 0 + var maxLength = 0 + arrays.foreach { array => + totalLength += array.length + maxLength = maxLength.max(array.length * 8) + } + val result = new Array[Byte](totalLength) + var resultBit = 0 + + var bit = 0 + while (bit < maxLength) { + val bytePos = bit / 8 + val bitPos = bit % 8 + + for (arr <- arrays) { + val len = arr.length + if (bytePos < len) { + val resultBytePos = totalLength - 1 - resultBit / 8 + val resultBitPos = resultBit % 8 + result(resultBytePos) = + updatePos(result(resultBytePos), resultBitPos, arr(len - 1 - bytePos), bitPos) + resultBit += 1 + } + } + bit += 1 + } + result + } + + def updatePos(a: Byte, apos: Int, b: Byte, bpos: Int): Byte = { + var temp = (b & (1 << bpos)).toByte + if (apos > bpos) { + temp = (temp << (apos - bpos)).toByte + } else if (apos < bpos) { + temp = (temp >> (bpos - apos)).toByte + } + val atemp = (a & (1 << apos)).toByte + if (atemp == 
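+    // updatePos copies bit `bpos` of `b` into bit `apos` of `a`: when the target bit already
+    // matches the source bit, `a` is returned unchanged, otherwise that single bit is flipped.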
temp) { + return a + } + (a ^ (1 << apos)).toByte + } + + def toLong(a: Any): Long = { + a match { + case b: Boolean => (if (b) 1 else 0).toLong ^ BIT_64_MASK + case b: Byte => b.toLong ^ BIT_64_MASK + case s: Short => s.toLong ^ BIT_64_MASK + case i: Int => i.toLong ^ BIT_64_MASK + case l: Long => l ^ BIT_64_MASK + case f: Float => java.lang.Float.floatToRawIntBits(f).toLong ^ BIT_64_MASK + case d: Double => java.lang.Double.doubleToRawLongBits(d) ^ BIT_64_MASK + case str: UTF8String => str.getPrefix + case dec: Decimal => dec.toLong ^ BIT_64_MASK + case other: Any => + throw new KyuubiSQLExtensionException("Unsupported z-order type: " + other.getClass) + } + } + + def toByteArray(a: Any): Array[Byte] = { + a match { + case bo: Boolean => + booleanToByte(bo) + case b: Byte => + byteToByte(b) + case s: Short => + shortToByte(s) + case i: Int => + intToByte(i) + case l: Long => + longToByte(l) + case f: Float => + floatToByte(f) + case d: Double => + doubleToByte(d) + case str: UTF8String => + // truncate or padding str to 8 byte + paddingTo8Byte(str.getBytes) + case dec: Decimal => + longToByte(dec.toLong) + case other: Any => + throw new KyuubiSQLExtensionException("Unsupported z-order type: " + other.getClass) + } + } + + def booleanToByte(a: Boolean): Array[Byte] = { + if (a) { + byteToByte(1.toByte) + } else { + byteToByte(0.toByte) + } + } + + def byteToByte(a: Byte): Array[Byte] = { + val tmp = (a ^ BIT_8_MASK).toByte + Array(tmp) + } + + def shortToByte(a: Short): Array[Byte] = { + val tmp = a ^ BIT_16_MASK + Array(((tmp >> 8) & 0xFF).toByte, (tmp & 0xFF).toByte) + } + + def intToByte(a: Int): Array[Byte] = { + val result = new Array[Byte](4) + var i = 0 + val tmp = a ^ BIT_32_MASK + while (i <= 3) { + val offset = i * 8 + result(3 - i) = ((tmp >> offset) & 0xFF).toByte + i += 1 + } + result + } + + def longToByte(a: Long): Array[Byte] = { + val result = new Array[Byte](8) + var i = 0 + val tmp = a ^ BIT_64_MASK + while (i <= 7) { + val offset = i * 8 + result(7 - i) = ((tmp >> offset) & 0xFF).toByte + i += 1 + } + result + } + + def floatToByte(a: Float): Array[Byte] = { + val fi = jFloat.floatToRawIntBits(a) + intToByte(fi) + } + + def doubleToByte(a: Double): Array[Byte] = { + val dl = jDouble.doubleToRawLongBits(a) + longToByte(dl) + } + + def paddingTo8Byte(a: Array[Byte]): Array[Byte] = { + val len = a.length + if (len == 8) { + a + } else if (len > 8) { + val result = new Array[Byte](8) + System.arraycopy(a, 0, result, 0, 8) + result + } else { + val result = new Array[Byte](8) + System.arraycopy(a, 0, result, 8 - len, len) + result + } + } + + def defaultByteArrayValue(dataType: DataType): Array[Byte] = toByteArray { + defaultValue(dataType) + } + + def defaultValue(dataType: DataType): Any = { + dataType match { + case BooleanType => + true + case ByteType => + Byte.MaxValue + case ShortType => + Short.MaxValue + case IntegerType | DateType => + Int.MaxValue + case LongType | TimestampType | _: DecimalType => + Long.MaxValue + case FloatType => + Float.MaxValue + case DoubleType => + Double.MaxValue + case StringType => + // we pad string to 8 bytes so it's equal to long + UTF8String.fromBytes(longToByte(Long.MaxValue)) + case other: Any => + throw new KyuubiSQLExtensionException(s"Unsupported z-order type: ${other.catalogString}") + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/FinalStageResourceManager.scala 
b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/FinalStageResourceManager.scala new file mode 100644 index 00000000000..81873476cc4 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/FinalStageResourceManager.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.annotation.tailrec +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{ExecutorAllocationClient, MapOutputTrackerMaster, SparkContext, SparkEnv} +import org.apache.spark.internal.Logging +import org.apache.spark.resource.ResourceProfile +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SortExec, SparkPlan} +import org.apache.spark.sql.execution.adaptive._ +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.command.DataWritingCommandExec +import org.apache.spark.sql.execution.datasources.WriteFilesExec +import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec} + +import org.apache.kyuubi.sql.{KyuubiSQLConf, WriteUtils} + +/** + * This rule assumes the final write stage has less cores requirement than previous, otherwise + * this rule would take no effect. + * + * It provide a feature: + * 1. Kill redundant executors before running final write stage + */ +case class FinalStageResourceManager(session: SparkSession) + extends Rule[SparkPlan] with FinalRebalanceStageHelper { + override def apply(plan: SparkPlan): SparkPlan = { + if (!conf.getConf(KyuubiSQLConf.FINAL_WRITE_STAGE_EAGERLY_KILL_EXECUTORS_ENABLED)) { + return plan + } + + if (!WriteUtils.isWrite(session, plan)) { + return plan + } + + val sc = session.sparkContext + val dra = sc.getConf.getBoolean("spark.dynamicAllocation.enabled", false) + val coresPerExecutor = sc.getConf.getInt("spark.executor.cores", 1) + val minExecutors = sc.getConf.getInt("spark.dynamicAllocation.minExecutors", 0) + val maxExecutors = sc.getConf.getInt("spark.dynamicAllocation.maxExecutors", Int.MaxValue) + val factor = conf.getConf(KyuubiSQLConf.FINAL_WRITE_STAGE_PARTITION_FACTOR) + val hasImprovementRoom = maxExecutors - 1 > minExecutors * factor + // Fast fail if: + // 1. DRA off + // 2. only work with yarn and k8s + // 3. 
maxExecutors is not bigger than minExecutors * factor + if (!dra || !sc.schedulerBackend.isInstanceOf[CoarseGrainedSchedulerBackend] || + !hasImprovementRoom) { + return plan + } + + val stageOpt = findFinalRebalanceStage(plan) + if (stageOpt.isEmpty) { + return plan + } + + // It's not safe to kill executors if this plan contains table cache. + // If the executor loses then the rdd would re-compute those partition. + if (hasTableCache(plan) && + conf.getConf(KyuubiSQLConf.FINAL_WRITE_STAGE_SKIP_KILLING_EXECUTORS_FOR_TABLE_CACHE)) { + return plan + } + + // TODO: move this to query stage optimizer when updating Spark to 3.5.x + // Since we are in `prepareQueryStage`, the AQE shuffle read has not been applied. + // So we need to apply it by self. + val shuffleRead = queryStageOptimizerRules.foldLeft(stageOpt.get.asInstanceOf[SparkPlan]) { + case (latest, rule) => rule.apply(latest) + } + val (targetCores, stage) = shuffleRead match { + case AQEShuffleReadExec(stage: ShuffleQueryStageExec, partitionSpecs) => + (partitionSpecs.length, stage) + case stage: ShuffleQueryStageExec => + // we can still kill executors if no AQE shuffle read, e.g., `.repartition(2)` + (stage.shuffle.numPartitions, stage) + case _ => + // it should never happen in current Spark, but to be safe do nothing if happens + logWarning("BUG, Please report to Apache Kyuubi community") + return plan + } + // The condition whether inject custom resource profile: + // - target executors < active executors + // - active executors - target executors > min executors + val numActiveExecutors = sc.getExecutorIds().length + val targetExecutors = (math.ceil(targetCores.toFloat / coresPerExecutor) * factor).toInt + .max(1) + val hasBenefits = targetExecutors < numActiveExecutors && + (numActiveExecutors - targetExecutors) > minExecutors + logInfo(s"The snapshot of current executors view, " + + s"active executors: $numActiveExecutors, min executor: $minExecutors, " + + s"target executors: $targetExecutors, has benefits: $hasBenefits") + if (hasBenefits) { + val shuffleId = stage.plan.asInstanceOf[ShuffleExchangeExec].shuffleDependency.shuffleId + val numReduce = stage.plan.asInstanceOf[ShuffleExchangeExec].numPartitions + // Now, there is only a final rebalance stage waiting to execute and all tasks of previous + // stage are finished. Kill redundant existed executors eagerly so the tasks of final + // stage can be centralized scheduled. + killExecutors(sc, targetExecutors, shuffleId, numReduce) + } + + plan + } + + /** + * The priority of kill executors follow: + * 1. kill executor who is younger than other (The older the JIT works better) + * 2. 
kill executor who produces less shuffle data first + */ + private def findExecutorToKill( + sc: SparkContext, + targetExecutors: Int, + shuffleId: Int, + numReduce: Int): Seq[String] = { + val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val shuffleStatusOpt = tracker.shuffleStatuses.get(shuffleId) + if (shuffleStatusOpt.isEmpty) { + return Seq.empty + } + val shuffleStatus = shuffleStatusOpt.get + val executorToBlockSize = new mutable.HashMap[String, Long] + shuffleStatus.withMapStatuses { mapStatus => + mapStatus.foreach { status => + var i = 0 + var sum = 0L + while (i < numReduce) { + sum += status.getSizeForBlock(i) + i += 1 + } + executorToBlockSize.getOrElseUpdate(status.location.executorId, sum) + } + } + + val backend = sc.schedulerBackend.asInstanceOf[CoarseGrainedSchedulerBackend] + val executorsWithRegistrationTs = backend.getExecutorsWithRegistrationTs() + val existedExecutors = executorsWithRegistrationTs.keys.toSet + val expectedNumExecutorToKill = existedExecutors.size - targetExecutors + if (expectedNumExecutorToKill < 1) { + return Seq.empty + } + + val executorIdsToKill = new ArrayBuffer[String]() + // We first kill executor who does not hold shuffle block. It would happen because + // the last stage is running fast and finished in a short time. The existed executors are + // from previous stages that have not been killed by DRA, so we can not find it by tracking + // shuffle status. + // We should evict executors by their alive time first and retain all of executors which + // have better locality for shuffle block. + executorsWithRegistrationTs.toSeq.sortBy(_._2).foreach { case (id, _) => + if (executorIdsToKill.length < expectedNumExecutorToKill && + !executorToBlockSize.contains(id)) { + executorIdsToKill.append(id) + } + } + + // Evict the rest executors according to the shuffle block size + executorToBlockSize.toSeq.sortBy(_._2).foreach { case (id, _) => + if (executorIdsToKill.length < expectedNumExecutorToKill && existedExecutors.contains(id)) { + executorIdsToKill.append(id) + } + } + + executorIdsToKill.toSeq + } + + private def killExecutors( + sc: SparkContext, + targetExecutors: Int, + shuffleId: Int, + numReduce: Int): Unit = { + val executorAllocationClient = sc.schedulerBackend.asInstanceOf[ExecutorAllocationClient] + + val executorsToKill = + if (conf.getConf(KyuubiSQLConf.FINAL_WRITE_STAGE_EAGERLY_KILL_EXECUTORS_KILL_ALL)) { + executorAllocationClient.getExecutorIds() + } else { + findExecutorToKill(sc, targetExecutors, shuffleId, numReduce) + } + logInfo(s"Request to kill executors, total count ${executorsToKill.size}, " + + s"[${executorsToKill.mkString(", ")}].") + if (executorsToKill.isEmpty) { + return + } + + // Note, `SparkContext#killExecutors` does not allow with DRA enabled, + // see `https://github.com/apache/spark/pull/20604`. + // It may cause the status in `ExecutorAllocationManager` inconsistent with + // `CoarseGrainedSchedulerBackend` for a while. But it should be synchronous finally. + // + // We should adjust target num executors, otherwise `YarnAllocator` might re-request original + // target executors if DRA has not updated target executors yet. + // Note, DRA would re-adjust executors if there are more tasks to be executed, so we are safe. 
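+    //
+    // In other words, `adjustTargetNumExecutors = true` lowers the DRA target together with
+    // the kill below, and the follow-up `requestExecutors(delta)` only tops the target back up
+    // when it has dropped under what the final stage still needs.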
+ // + // * We kill executor + // * YarnAllocator re-request target executors + // * DRA can not release executors since they are new added + // ----------------------------------------------------------------> timeline + executorAllocationClient.killExecutors( + executorIds = executorsToKill, + adjustTargetNumExecutors = true, + countFailures = false, + force = false) + + FinalStageResourceManager.getAdjustedTargetExecutors(sc) + .filter(_ < targetExecutors).foreach { adjustedExecutors => + val delta = targetExecutors - adjustedExecutors + logInfo(s"Target executors after kill ($adjustedExecutors) is lower than required " + + s"($targetExecutors). Requesting $delta additional executor(s).") + executorAllocationClient.requestExecutors(delta) + } + } + + @transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq( + OptimizeSkewInRebalancePartitions, + CoalesceShufflePartitions(session), + OptimizeShuffleWithLocalRead) +} + +object FinalStageResourceManager extends Logging { + + private[sql] def getAdjustedTargetExecutors(sc: SparkContext): Option[Int] = { + sc.schedulerBackend match { + case schedulerBackend: CoarseGrainedSchedulerBackend => + try { + val field = classOf[CoarseGrainedSchedulerBackend] + .getDeclaredField("requestedTotalExecutorsPerResourceProfile") + field.setAccessible(true) + schedulerBackend.synchronized { + val requestedTotalExecutorsPerResourceProfile = + field.get(schedulerBackend).asInstanceOf[mutable.HashMap[ResourceProfile, Int]] + val defaultRp = sc.resourceProfileManager.defaultResourceProfile + requestedTotalExecutorsPerResourceProfile.get(defaultRp) + } + } catch { + case e: Exception => + logWarning("Failed to get requestedTotalExecutors of Default ResourceProfile", e) + None + } + case _ => None + } + } +} + +trait FinalRebalanceStageHelper extends AdaptiveSparkPlanHelper { + @tailrec + final protected def findFinalRebalanceStage(plan: SparkPlan): Option[ShuffleQueryStageExec] = { + plan match { + case write: DataWritingCommandExec => findFinalRebalanceStage(write.child) + case write: V2TableWriteExec => findFinalRebalanceStage(write.child) + case write: WriteFilesExec => findFinalRebalanceStage(write.child) + case p: ProjectExec => findFinalRebalanceStage(p.child) + case f: FilterExec => findFinalRebalanceStage(f.child) + case s: SortExec if !s.global => findFinalRebalanceStage(s.child) + case stage: ShuffleQueryStageExec + if stage.isMaterialized && stage.mapStats.isDefined && + stage.plan.isInstanceOf[ShuffleExchangeExec] && + stage.plan.asInstanceOf[ShuffleExchangeExec].shuffleOrigin != ENSURE_REQUIREMENTS => + Some(stage) + case _ => None + } + } + + final protected def hasTableCache(plan: SparkPlan): Boolean = { + find(plan) { + case _: InMemoryTableScanExec => true + case _ => false + }.isDefined + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/InjectCustomResourceProfile.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/InjectCustomResourceProfile.scala new file mode 100644 index 00000000000..64421d6bfab --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/InjectCustomResourceProfile.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{CustomResourceProfileExec, SparkPlan} +import org.apache.spark.sql.execution.adaptive._ + +import org.apache.kyuubi.sql.{KyuubiSQLConf, WriteUtils} + +/** + * Inject custom resource profile for final write stage, so we can specify custom + * executor resource configs. + */ +case class InjectCustomResourceProfile(session: SparkSession) + extends Rule[SparkPlan] with FinalRebalanceStageHelper { + override def apply(plan: SparkPlan): SparkPlan = { + if (!conf.getConf(KyuubiSQLConf.FINAL_WRITE_STAGE_RESOURCE_ISOLATION_ENABLED)) { + return plan + } + + if (!WriteUtils.isWrite(session, plan)) { + return plan + } + + val stage = findFinalRebalanceStage(plan) + if (stage.isEmpty) { + return plan + } + + // TODO: Ideally, We can call `CoarseGrainedSchedulerBackend.requestTotalExecutors` eagerly + // to reduce the task submit pending time, but it may lose task locality. + // + // By default, it would request executors when catch stage submit event. + injectCustomResourceProfile(plan, stage.get.id) + } + + private def injectCustomResourceProfile(plan: SparkPlan, id: Int): SparkPlan = { + plan match { + case stage: ShuffleQueryStageExec if stage.id == id => + CustomResourceProfileExec(stage) + case _ => plan.mapChildren(child => injectCustomResourceProfile(child, id)) + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala new file mode 100644 index 00000000000..ce496eb474c --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/PruneFileSourcePartitionHelper.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.plans.logical.LeafNode +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.types.StructType + +trait PruneFileSourcePartitionHelper extends PredicateHelper { + + def getPartitionKeyFiltersAndDataFilters( + sparkSession: SparkSession, + relation: LeafNode, + partitionSchema: StructType, + filters: Seq[Expression], + output: Seq[AttributeReference]): (ExpressionSet, Seq[Expression]) = { + val normalizedFilters = DataSourceStrategy.normalizeExprs( + filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), + output) + val partitionColumns = + relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver) + val partitionSet = AttributeSet(partitionColumns) + val (partitionFilters, dataFilters) = normalizedFilters.partition(f => + f.references.subsetOf(partitionSet)) + val extraPartitionFilter = + dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet)) + + (ExpressionSet(partitionFilters ++ extraPartitionFilter), dataFilters) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/execution/CustomResourceProfileExec.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/execution/CustomResourceProfileExec.scala new file mode 100644 index 00000000000..3698140fbd0 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/main/scala/org/apache/spark/sql/execution/CustomResourceProfileExec.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.network.util.{ByteUnit, JavaUtils} +import org.apache.spark.rdd.RDD +import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfileBuilder} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils + +import org.apache.kyuubi.sql.KyuubiSQLConf._ + +/** + * This node wraps the final executed plan and inject custom resource profile to the RDD. + * It assumes that, the produced RDD would create the `ResultStage` in `DAGScheduler`, + * so it makes resource isolation between previous and final stage. + * + * Note that, Spark does not support config `minExecutors` for each resource profile. 
+ * Which means, it would retain `minExecutors` for each resource profile. + * So, suggest set `spark.dynamicAllocation.minExecutors` to 0 if enable this feature. + */ +case class CustomResourceProfileExec(child: SparkPlan) extends UnaryExecNode { + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def supportsColumnar: Boolean = child.supportsColumnar + override def supportsRowBased: Boolean = child.supportsRowBased + override protected def doCanonicalize(): SparkPlan = child.canonicalized + + private val executorCores = conf.getConf(FINAL_WRITE_STAGE_EXECUTOR_CORES).getOrElse( + sparkContext.getConf.getInt("spark.executor.cores", 1)) + private val executorMemory = conf.getConf(FINAL_WRITE_STAGE_EXECUTOR_MEMORY).getOrElse( + sparkContext.getConf.get("spark.executor.memory", "2G")) + private val executorMemoryOverhead = + conf.getConf(FINAL_WRITE_STAGE_EXECUTOR_MEMORY_OVERHEAD) + .getOrElse(sparkContext.getConf.get("spark.executor.memoryOverhead", "1G")) + private val executorOffHeapMemory = conf.getConf(FINAL_WRITE_STAGE_EXECUTOR_OFF_HEAP_MEMORY) + + override lazy val metrics: Map[String, SQLMetric] = { + val base = Map( + "executorCores" -> SQLMetrics.createMetric(sparkContext, "executor cores"), + "executorMemory" -> SQLMetrics.createMetric(sparkContext, "executor memory (MiB)"), + "executorMemoryOverhead" -> SQLMetrics.createMetric( + sparkContext, + "executor memory overhead (MiB)")) + val addition = executorOffHeapMemory.map(_ => + "executorOffHeapMemory" -> + SQLMetrics.createMetric(sparkContext, "executor off heap memory (MiB)")).toMap + base ++ addition + } + + private def wrapResourceProfile[T](rdd: RDD[T]): RDD[T] = { + if (Utils.isTesting) { + // do nothing for local testing + return rdd + } + + metrics("executorCores") += executorCores + metrics("executorMemory") += JavaUtils.byteStringAs(executorMemory, ByteUnit.MiB) + metrics("executorMemoryOverhead") += JavaUtils.byteStringAs( + executorMemoryOverhead, + ByteUnit.MiB) + executorOffHeapMemory.foreach(m => + metrics("executorOffHeapMemory") += JavaUtils.byteStringAs(m, ByteUnit.MiB)) + + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) + + val resourceProfileBuilder = new ResourceProfileBuilder() + val executorResourceRequests = new ExecutorResourceRequests() + executorResourceRequests.cores(executorCores) + executorResourceRequests.memory(executorMemory) + executorResourceRequests.memoryOverhead(executorMemoryOverhead) + executorOffHeapMemory.foreach(executorResourceRequests.offHeapMemory) + resourceProfileBuilder.require(executorResourceRequests) + rdd.withResources(resourceProfileBuilder.build()) + rdd + } + + override protected def doExecute(): RDD[InternalRow] = { + val rdd = child.execute() + wrapResourceProfile(rdd) + } + + override protected def doExecuteColumnar(): RDD[ColumnarBatch] = { + val rdd = child.executeColumnar() + wrapResourceProfile(rdd) + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = { + this.copy(child = newChild) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/resources/log4j2-test.xml b/extensions/spark/kyuubi-extension-spark-3-5/src/test/resources/log4j2-test.xml new file mode 100644 index 00000000000..bfc40dd6df4 --- /dev/null +++ 
b/extensions/spark/kyuubi-extension-spark-3-5/src/test/resources/log4j2-test.xml @@ -0,0 +1,43 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/DropIgnoreNonexistentSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/DropIgnoreNonexistentSuite.scala new file mode 100644 index 00000000000..bbc61fb4408 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/DropIgnoreNonexistentSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.logical.{DropNamespace, NoopCommand} +import org.apache.spark.sql.execution.command._ + +import org.apache.kyuubi.sql.KyuubiSQLConf + +class DropIgnoreNonexistentSuite extends KyuubiSparkSQLExtensionTest { + + test("drop ignore nonexistent") { + withSQLConf(KyuubiSQLConf.DROP_IGNORE_NONEXISTENT.key -> "true") { + // drop nonexistent database + val df1 = sql("DROP DATABASE nonexistent_database") + assert(df1.queryExecution.analyzed.asInstanceOf[DropNamespace].ifExists == true) + + // drop nonexistent function + val df4 = sql("DROP FUNCTION nonexistent_function") + assert(df4.queryExecution.analyzed.isInstanceOf[NoopCommand]) + + // drop nonexistent PARTITION + withTable("test") { + sql("CREATE TABLE IF NOT EXISTS test(i int) PARTITIONED BY (p int)") + val df5 = sql("ALTER TABLE test DROP PARTITION (p = 1)") + assert(df5.queryExecution.analyzed + .asInstanceOf[AlterTableDropPartitionCommand].ifExists == true) + } + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/FinalStageConfigIsolationSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/FinalStageConfigIsolationSuite.scala new file mode 100644 index 00000000000..96c8ae6e8b0 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/FinalStageConfigIsolationSuite.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
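The suite above exercises `KyuubiSQLConf.DROP_IGNORE_NONEXISTENT`. As a hypothetical user-facing illustration (not taken from the patch, and assuming an active SparkSession `spark` with the Kyuubi extension installed), enabling the flag makes drops of missing objects complete as no-ops instead of failing:

// Hypothetical session-level usage of the flag tested by DropIgnoreNonexistentSuite.
import org.apache.kyuubi.sql.KyuubiSQLConf

spark.conf.set(KyuubiSQLConf.DROP_IGNORE_NONEXISTENT.key, "true")

// Rewritten by the extension to IF EXISTS / no-op semantics, so these complete
// even though the objects do not exist.
spark.sql("DROP DATABASE nonexistent_database")
spark.sql("DROP FUNCTION nonexistent_function")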
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, QueryStageExec} +import org.apache.spark.sql.internal.SQLConf + +import org.apache.kyuubi.sql.{FinalStageConfigIsolation, KyuubiSQLConf} + +class FinalStageConfigIsolationSuite extends KyuubiSparkSQLExtensionTest { + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + test("final stage config set reset check") { + withSQLConf( + KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION.key -> "true", + KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION_WRITE_ONLY.key -> "false", + "spark.sql.finalStage.adaptive.coalescePartitions.minPartitionNum" -> "1", + "spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes" -> "100") { + // use loop to double check final stage config doesn't affect the sql query each other + (1 to 3).foreach { _ => + sql("SELECT COUNT(*) FROM VALUES(1) as t(c)").collect() + assert(spark.sessionState.conf.getConfString( + "spark.sql.previousStage.adaptive.coalescePartitions.minPartitionNum") === + FinalStageConfigIsolation.INTERNAL_UNSET_CONFIG_TAG) + assert(spark.sessionState.conf.getConfString( + "spark.sql.adaptive.coalescePartitions.minPartitionNum") === + "1") + assert(spark.sessionState.conf.getConfString( + "spark.sql.finalStage.adaptive.coalescePartitions.minPartitionNum") === + "1") + + // 64MB + assert(spark.sessionState.conf.getConfString( + "spark.sql.previousStage.adaptive.advisoryPartitionSizeInBytes") === + "67108864b") + assert(spark.sessionState.conf.getConfString( + "spark.sql.adaptive.advisoryPartitionSizeInBytes") === + "100") + assert(spark.sessionState.conf.getConfString( + "spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes") === + "100") + } + + sql("SET spark.sql.adaptive.advisoryPartitionSizeInBytes=1") + assert(spark.sessionState.conf.getConfString( + "spark.sql.adaptive.advisoryPartitionSizeInBytes") === + "1") + assert(!spark.sessionState.conf.contains( + "spark.sql.previousStage.adaptive.advisoryPartitionSizeInBytes")) + + sql("SET a=1") + assert(spark.sessionState.conf.getConfString("a") === "1") + + sql("RESET spark.sql.adaptive.coalescePartitions.minPartitionNum") + assert(!spark.sessionState.conf.contains( + "spark.sql.adaptive.coalescePartitions.minPartitionNum")) + assert(!spark.sessionState.conf.contains( + "spark.sql.previousStage.adaptive.coalescePartitions.minPartitionNum")) + + sql("RESET a") + assert(!spark.sessionState.conf.contains("a")) + } + } + + test("final stage config isolation") { + def checkPartitionNum( + sqlString: String, + previousPartitionNum: Int, + finalPartitionNum: Int): Unit = { + val df = sql(sqlString) + df.collect() + val shuffleReaders = collect(df.queryExecution.executedPlan) { + case customShuffleReader: AQEShuffleReadExec => customShuffleReader + } + assert(shuffleReaders.nonEmpty) + // reorder stage by stage id to ensure we get the right stage + val sortedShuffleReaders = shuffleReaders.sortWith { + case (s1, s2) => + s1.child.asInstanceOf[QueryStageExec].id < s2.child.asInstanceOf[QueryStageExec].id + } + if (sortedShuffleReaders.length > 1) { + 
assert(sortedShuffleReaders.head.partitionSpecs.length === previousPartitionNum) + } + assert(sortedShuffleReaders.last.partitionSpecs.length === finalPartitionNum) + assert(df.rdd.partitions.length === finalPartitionNum) + } + + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION.key -> "true", + KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION_WRITE_ONLY.key -> "false", + "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> "1", + "spark.sql.adaptive.coalescePartitions.minPartitionSize" -> "1", + "spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes" -> "10000000") { + + // use loop to double check final stage config doesn't affect the sql query each other + (1 to 3).foreach { _ => + checkPartitionNum( + "SELECT c1, count(*) FROM t1 GROUP BY c1", + 1, + 1) + + checkPartitionNum( + "SELECT c2, count(*) FROM (SELECT c1, count(*) as c2 FROM t1 GROUP BY c1) GROUP BY c2", + 3, + 1) + + checkPartitionNum( + "SELECT t1.c1, count(*) FROM t1 JOIN t2 ON t1.c2 = t2.c2 GROUP BY t1.c1", + 3, + 1) + + checkPartitionNum( + """ + | SELECT /*+ REPARTITION */ + | t1.c1, count(*) FROM t1 + | JOIN t2 ON t1.c2 = t2.c2 + | JOIN t3 ON t1.c1 = t3.c1 + | GROUP BY t1.c1 + |""".stripMargin, + 3, + 1) + + // one shuffle reader + checkPartitionNum( + """ + | SELECT /*+ BROADCAST(t1) */ + | t1.c1, t2.c2 FROM t1 + | JOIN t2 ON t1.c2 = t2.c2 + | DISTRIBUTE BY c1 + |""".stripMargin, + 1, + 1) + + // test ReusedExchange + checkPartitionNum( + """ + |SELECT /*+ REPARTITION */ t0.c2 FROM ( + |SELECT t1.c1, (count(*) + c1) as c2 FROM t1 GROUP BY t1.c1 + |) t0 JOIN ( + |SELECT t1.c1, (count(*) + c1) as c2 FROM t1 GROUP BY t1.c1 + |) t1 ON t0.c2 = t1.c2 + |""".stripMargin, + 3, + 1) + + // one shuffle reader + checkPartitionNum( + """ + |SELECT t0.c1 FROM ( + |SELECT t1.c1 FROM t1 GROUP BY t1.c1 + |) t0 JOIN ( + |SELECT t1.c1 FROM t1 GROUP BY t1.c1 + |) t1 ON t0.c1 = t1.c1 + |""".stripMargin, + 1, + 1) + } + } + } + + test("final stage config isolation write only") { + withSQLConf( + KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION.key -> "true", + KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION_WRITE_ONLY.key -> "true", + "spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes" -> "7") { + sql("set spark.sql.adaptive.advisoryPartitionSizeInBytes=5") + sql("SELECT * FROM t1").count() + assert(spark.conf.getOption("spark.sql.adaptive.advisoryPartitionSizeInBytes") + .contains("5")) + + withTable("tmp") { + sql("CREATE TABLE t1 USING PARQUET SELECT /*+ repartition */ 1 AS c1, 'a' AS c2") + assert(spark.conf.getOption("spark.sql.adaptive.advisoryPartitionSizeInBytes") + .contains("7")) + } + + sql("SELECT * FROM t1").count() + assert(spark.conf.getOption("spark.sql.adaptive.advisoryPartitionSizeInBytes") + .contains("5")) + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/FinalStageResourceManagerSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/FinalStageResourceManagerSuite.scala new file mode 100644 index 00000000000..4b9991ef6f2 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/FinalStageResourceManagerSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
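Condensing what the assertions above verify into a user-facing view (a hypothetical sketch, assuming a session `spark` with the extension installed and hypothetical tables): configs under the `spark.sql.finalStage.` prefix take effect when the final write stage is planned, and the regular values are stashed under the `spark.sql.previousStage.` prefix.

// Hypothetical illustration of final-stage config isolation.
import org.apache.kyuubi.sql.KyuubiSQLConf

spark.conf.set(KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION.key, "true")
// Small advisory partitions for intermediate stages, large ones for the final write:
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "16m")
spark.conf.set("spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes", "256m")

spark.sql("INSERT INTO TABLE target SELECT * FROM source")  // hypothetical tables

// After the query, the final-stage value is in effect and the original value is
// preserved under the previousStage prefix, mirroring the suite's assertions:
spark.conf.get("spark.sql.adaptive.advisoryPartitionSizeInBytes")                // final-stage value
spark.conf.get("spark.sql.previousStage.adaptive.advisoryPartitionSizeInBytes")  // stashed original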
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkConf +import org.scalatest.time.{Minutes, Span} + +import org.apache.kyuubi.sql.KyuubiSQLConf +import org.apache.kyuubi.tags.SparkLocalClusterTest + +@SparkLocalClusterTest +class FinalStageResourceManagerSuite extends KyuubiSparkSQLExtensionTest { + + override def sparkConf(): SparkConf = { + // It is difficult to run spark in local-cluster mode when spark.testing is set. + sys.props.remove("spark.testing") + + super.sparkConf().set("spark.master", "local-cluster[3, 1, 1024]") + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.initialExecutors", "3") + .set("spark.dynamicAllocation.minExecutors", "1") + .set("spark.dynamicAllocation.shuffleTracking.enabled", "true") + .set(KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION.key, "true") + .set(KyuubiSQLConf.FINAL_WRITE_STAGE_EAGERLY_KILL_EXECUTORS_ENABLED.key, "true") + } + + test("[KYUUBI #5136][Bug] Final Stage hangs forever") { + // Prerequisite to reproduce the bug: + // 1. Dynamic allocation is enabled. + // 2. Dynamic allocation min executors is 1. + // 3. target executors < active executors. + // 4. No active executor is left after FinalStageResourceManager killed executors. + // This is possible because FinalStageResourceManager retained executors may already be + // requested to be killed but not died yet. + // 5. Final Stage required executors is 1. + withSQLConf( + (KyuubiSQLConf.FINAL_WRITE_STAGE_EAGERLY_KILL_EXECUTORS_KILL_ALL.key, "true")) { + withTable("final_stage") { + eventually(timeout(Span(10, Minutes))) { + sql( + "CREATE TABLE final_stage AS SELECT id, count(*) as num FROM (SELECT 0 id) GROUP BY id") + } + assert(FinalStageResourceManager.getAdjustedTargetExecutors(spark.sparkContext).get == 1) + } + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InjectResourceProfileSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InjectResourceProfileSuite.scala new file mode 100644 index 00000000000..b0767b18708 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InjectResourceProfileSuite.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent} +import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate + +import org.apache.kyuubi.sql.KyuubiSQLConf + +class InjectResourceProfileSuite extends KyuubiSparkSQLExtensionTest { + private def checkCustomResourceProfile(sqlString: String, exists: Boolean): Unit = { + @volatile var lastEvent: SparkListenerSQLAdaptiveExecutionUpdate = null + val listener = new SparkListener { + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case e: SparkListenerSQLAdaptiveExecutionUpdate => lastEvent = e + case _ => + } + } + } + + spark.sparkContext.addSparkListener(listener) + try { + sql(sqlString).collect() + spark.sparkContext.listenerBus.waitUntilEmpty() + assert(lastEvent != null) + var current = lastEvent.sparkPlanInfo + var shouldStop = false + while (!shouldStop) { + if (current.nodeName != "CustomResourceProfile") { + if (current.children.isEmpty) { + assert(!exists) + shouldStop = true + } else { + current = current.children.head + } + } else { + assert(exists) + shouldStop = true + } + } + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } + + test("Inject resource profile") { + withTable("t") { + withSQLConf( + "spark.sql.adaptive.forceApply" -> "true", + KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION.key -> "true", + KyuubiSQLConf.FINAL_WRITE_STAGE_RESOURCE_ISOLATION_ENABLED.key -> "true") { + + sql("CREATE TABLE t (c1 int, c2 string) USING PARQUET") + + checkCustomResourceProfile("INSERT INTO TABLE t VALUES(1, 'a')", false) + checkCustomResourceProfile("SELECT 1", false) + checkCustomResourceProfile( + "INSERT INTO TABLE t SELECT /*+ rebalance */ * FROM VALUES(1, 'a')", + true) + } + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InsertShuffleNodeBeforeJoinSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InsertShuffleNodeBeforeJoinSuite.scala new file mode 100644 index 00000000000..f0d38465734 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InsertShuffleNodeBeforeJoinSuite.scala @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql + +class InsertShuffleNodeBeforeJoinSuite extends InsertShuffleNodeBeforeJoinSuiteBase diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InsertShuffleNodeBeforeJoinSuiteBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InsertShuffleNodeBeforeJoinSuiteBase.scala new file mode 100644 index 00000000000..c657dee49f3 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/InsertShuffleNodeBeforeJoinSuiteBase.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkConf +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike} +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} + +import org.apache.kyuubi.sql.KyuubiSQLConf + +trait InsertShuffleNodeBeforeJoinSuiteBase extends KyuubiSparkSQLExtensionTest { + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + override def sparkConf(): SparkConf = { + super.sparkConf() + .set( + StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, + "org.apache.kyuubi.sql.KyuubiSparkSQLCommonExtension") + } + + test("force shuffle before join") { + def checkShuffleNodeNum(sqlString: String, num: Int): Unit = { + var expectedResult: Seq[Row] = Seq.empty + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + expectedResult = sql(sqlString).collect() + } + val df = sql(sqlString) + checkAnswer(df, expectedResult) + assert( + collect(df.queryExecution.executedPlan) { + case shuffle: ShuffleExchangeLike if shuffle.shuffleOrigin == ENSURE_REQUIREMENTS => + shuffle + }.size == num) + } + + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + KyuubiSQLConf.FORCE_SHUFFLE_BEFORE_JOIN.key -> "true") { + Seq("SHUFFLE_HASH", "MERGE").foreach { joinHint => + // positive case + checkShuffleNodeNum( + s""" + |SELECT /*+ $joinHint(t2, t3) */ t1.c1, t1.c2, t2.c1, t3.c1 from t1 + | JOIN t2 ON t1.c1 = t2.c1 + | JOIN t3 ON t1.c1 = t3.c1 + | """.stripMargin, + 4) + + // negative case + checkShuffleNodeNum( + s""" + |SELECT /*+ $joinHint(t2, t3) */ t1.c1, t1.c2, t2.c1, t3.c1 from t1 + | JOIN t2 ON t1.c1 = t2.c1 + | JOIN t3 ON t1.c2 = t3.c2 + | """.stripMargin, + 4) + } + + checkShuffleNodeNum( + """ + |SELECT t1.c1, t2.c1, t3.c2 from t1 + | JOIN t2 ON t1.c1 = t2.c1 + | JOIN ( + | SELECT c2, count(*) FROM t1 GROUP BY c2 + | ) t3 ON t1.c1 = t3.c2 + | """.stripMargin, + 5) + + checkShuffleNodeNum( + """ + |SELECT t1.c1, t2.c1, t3.c1 from t1 + | JOIN t2 ON t1.c1 = t2.c1 + | JOIN ( + | SELECT c1, count(*) FROM t1 GROUP BY c1 + | ) t3 ON t1.c1 = t3.c1 + | """.stripMargin, + 5) + } + } +} diff --git 
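The check used throughout the suite above generalizes to a small helper for counting the shuffles the planner inserted to satisfy distribution requirements; this is an illustrative sketch, not code from the patch, reusing only types that appear in the suite:

// Count exchanges inserted for ENSURE_REQUIREMENTS, i.e. the shuffles that
// FORCE_SHUFFLE_BEFORE_JOIN adds in front of joins. Uses the collect from
// AdaptiveSparkPlanHelper so AQE sub-plans are traversed as well.
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike}

trait ForcedShuffleInspection extends AdaptiveSparkPlanHelper {
  def countForcedShuffles(df: DataFrame): Int =
    collect(df.queryExecution.executedPlan) {
      case s: ShuffleExchangeLike if s.shuffleOrigin == ENSURE_REQUIREMENTS => s
    }.size
}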
a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala new file mode 100644 index 00000000000..dd9ffbf169e --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkConf +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.command.{DataWritingCommand, DataWritingCommandExec} +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.test.SQLTestData.TestData +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.util.QueryExecutionListener +import org.apache.spark.util.Utils + +import org.apache.kyuubi.sql.KyuubiSQLConf + +trait KyuubiSparkSQLExtensionTest extends QueryTest + with SQLTestUtils + with AdaptiveSparkPlanHelper { + sys.props.put("spark.testing", "1") + + private var _spark: Option[SparkSession] = None + protected def spark: SparkSession = _spark.getOrElse { + throw new RuntimeException("test spark session don't initial before using it.") + } + + override protected def beforeAll(): Unit = { + if (_spark.isEmpty) { + _spark = Option(SparkSession.builder() + .master("local[1]") + .config(sparkConf) + .enableHiveSupport() + .getOrCreate()) + } + super.beforeAll() + } + + override protected def afterAll(): Unit = { + super.afterAll() + cleanupData() + _spark.foreach(_.stop) + } + + protected def setupData(): Unit = { + val self = spark + import self.implicits._ + spark.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString)), + 10) + .toDF("c1", "c2").createOrReplaceTempView("t1") + spark.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString)), + 5) + .toDF("c1", "c2").createOrReplaceTempView("t2") + spark.sparkContext.parallelize( + (1 to 50).map(i => TestData(i, i.toString)), + 2) + .toDF("c1", "c2").createOrReplaceTempView("t3") + } + + private def cleanupData(): Unit = { + spark.sql("DROP VIEW IF EXISTS t1") + spark.sql("DROP VIEW IF EXISTS t2") + spark.sql("DROP VIEW IF EXISTS t3") + } + + def sparkConf(): SparkConf = { + val basePath = Utils.createTempDir() + "/" + getClass.getCanonicalName + val metastorePath = basePath + "/metastore_db" + val warehousePath = basePath + "/warehouse" + new SparkConf() + .set( + StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, + "org.apache.kyuubi.sql.KyuubiSparkSQLExtension") + 
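        // The line above is what wires the Kyuubi extension into the test session.
        // Outside of tests the same registration is typically done through the
        // corresponding SQL config, e.g. a hypothetical spark-defaults.conf entry:
        //   spark.sql.extensions=org.apache.kyuubi.sql.KyuubiSparkSQLExtension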
.set(KyuubiSQLConf.SQL_CLASSIFICATION_ENABLED.key, "true") + .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") + .set("spark.hadoop.hive.exec.dynamic.partition.mode", "nonstrict") + .set("spark.hadoop.hive.metastore.client.capability.check", "false") + .set( + ConfVars.METASTORECONNECTURLKEY.varname, + s"jdbc:derby:;databaseName=$metastorePath;create=true") + .set(StaticSQLConf.WAREHOUSE_PATH, warehousePath) + .set("spark.ui.enabled", "false") + } + + def withListener(sqlString: String)(callback: DataWritingCommand => Unit): Unit = { + withListener(sql(sqlString))(callback) + } + + def withListener(df: => DataFrame)(callback: DataWritingCommand => Unit): Unit = { + val listener = new QueryExecutionListener { + override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {} + + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + qe.executedPlan match { + case write: DataWritingCommandExec => callback(write.cmd) + case _ => + } + } + } + spark.listenerManager.register(listener) + try { + df.collect() + sparkContext.listenerBus.waitUntilEmpty() + } finally { + spark.listenerManager.unregister(listener) + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala new file mode 100644 index 00000000000..1d9630f4937 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartitions, Sort} +import org.apache.spark.sql.execution.command.DataWritingCommand +import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand +import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.execution.InsertIntoHiveTable + +import org.apache.kyuubi.sql.KyuubiSQLConf + +class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest { + + test("check rebalance exists") { + def check(df: => DataFrame, expectedRebalanceNum: Int = 1): Unit = { + withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "true") { + withListener(df) { write => + assert(write.collect { + case r: RebalancePartitions => r + }.size == expectedRebalanceNum) + } + } + withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "false") { + withListener(df) { write => + assert(write.collect { + case r: RebalancePartitions => r + }.isEmpty) + } + } + } + + // It's better to set config explicitly in case of we change the default value. + withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "true") { + Seq("USING PARQUET", "").foreach { storage => + withTable("tmp1") { + sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)") + check(sql("INSERT INTO TABLE tmp1 PARTITION(c2='a') " + + "SELECT * FROM VALUES(1),(2) AS t(c1)")) + } + + withTable("tmp1", "tmp2") { + sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)") + sql(s"CREATE TABLE tmp2 (c1 int) $storage PARTITIONED BY (c2 string)") + check( + sql( + """FROM VALUES(1),(2) + |INSERT INTO TABLE tmp1 PARTITION(c2='a') SELECT * + |INSERT INTO TABLE tmp2 PARTITION(c2='a') SELECT * + |""".stripMargin), + 2) + } + + withTable("tmp1") { + sql(s"CREATE TABLE tmp1 (c1 int) $storage") + check(sql("INSERT INTO TABLE tmp1 SELECT * FROM VALUES(1),(2),(3) AS t(c1)")) + } + + withTable("tmp1", "tmp2") { + sql(s"CREATE TABLE tmp1 (c1 int) $storage") + sql(s"CREATE TABLE tmp2 (c1 int) $storage") + check( + sql( + """FROM VALUES(1),(2),(3) + |INSERT INTO TABLE tmp1 SELECT * + |INSERT INTO TABLE tmp2 SELECT * + |""".stripMargin), + 2) + } + + withTable("tmp1") { + sql(s"CREATE TABLE tmp1 $storage AS SELECT * FROM VALUES(1),(2),(3) AS t(c1)") + } + + withTable("tmp1") { + sql(s"CREATE TABLE tmp1 $storage PARTITIONED BY(c2) AS " + + s"SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)") + } + } + } + } + + test("check rebalance does not exists") { + def check(df: DataFrame): Unit = { + withListener(df) { write => + assert(write.collect { + case r: RebalancePartitions => r + }.isEmpty) + } + } + + withSQLConf( + KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "true", + KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "true") { + // test no write command + check(sql("SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)")) + check(sql("SELECT count(*) FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)")) + + // test not supported plan + withTable("tmp1") { + sql(s"CREATE TABLE tmp1 (c1 int) PARTITIONED BY (c2 string)") + check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " + + "SELECT /*+ repartition(10) */ * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)")) + check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " + + "SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2) ORDER BY c1")) + check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " + + "SELECT * FROM VALUES(1, 'a'),(2, 'b') AS 
t(c1, c2) LIMIT 10")) + } + } + + withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "false") { + Seq("USING PARQUET", "").foreach { storage => + withTable("tmp1") { + sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)") + check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " + + "SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)")) + } + + withTable("tmp1") { + sql(s"CREATE TABLE tmp1 (c1 int) $storage") + check(sql("INSERT INTO TABLE tmp1 SELECT * FROM VALUES(1),(2),(3) AS t(c1)")) + } + } + } + } + + test("test dynamic partition write") { + def checkRepartitionExpression(sqlString: String): Unit = { + withListener(sqlString) { write => + assert(write.isInstanceOf[InsertIntoHiveTable]) + assert(write.collect { + case r: RebalancePartitions if r.partitionExpressions.size == 1 => + assert(r.partitionExpressions.head.asInstanceOf[Attribute].name === "c2") + r + }.size == 1) + } + } + + withSQLConf( + KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "true", + KyuubiSQLConf.DYNAMIC_PARTITION_INSERTION_REPARTITION_NUM.key -> "2", + KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "true") { + Seq("USING PARQUET", "").foreach { storage => + withTable("tmp1") { + sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)") + checkRepartitionExpression("INSERT INTO TABLE tmp1 SELECT 1 as c1, 'a' as c2 ") + } + + withTable("tmp1") { + checkRepartitionExpression( + "CREATE TABLE tmp1 PARTITIONED BY(C2) SELECT 1 as c1, 'a' as c2") + } + } + } + } + + test("OptimizedCreateHiveTableAsSelectCommand") { + withSQLConf( + HiveUtils.CONVERT_METASTORE_PARQUET.key -> "true", + HiveUtils.CONVERT_METASTORE_CTAS.key -> "true", + KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "true") { + withTable("t") { + withListener("CREATE TABLE t STORED AS parquet AS SELECT 1 as a") { write => + assert(write.isInstanceOf[InsertIntoHadoopFsRelationCommand]) + assert(write.collect { + case _: RebalancePartitions => true + }.size == 1) + } + } + } + } + + test("Infer rebalance and sorder orders") { + def checkShuffleAndSort(dataWritingCommand: LogicalPlan, sSize: Int, rSize: Int): Unit = { + assert(dataWritingCommand.isInstanceOf[DataWritingCommand]) + val plan = dataWritingCommand.asInstanceOf[DataWritingCommand].query + assert(plan.collect { + case s: Sort => s + }.size == sSize) + assert(plan.collect { + case r: RebalancePartitions if r.partitionExpressions.size == rSize => r + }.nonEmpty || rSize == 0) + } + + withView("v") { + withTable("t", "input1", "input2") { + withSQLConf(KyuubiSQLConf.INFER_REBALANCE_AND_SORT_ORDERS.key -> "true") { + sql(s"CREATE TABLE t (c1 int, c2 long) USING PARQUET PARTITIONED BY (p string)") + sql(s"CREATE TABLE input1 USING PARQUET AS SELECT * FROM VALUES(1,2),(1,3)") + sql(s"CREATE TABLE input2 USING PARQUET AS SELECT * FROM VALUES(1,3),(1,3)") + sql(s"CREATE VIEW v as SELECT col1, count(*) as col2 FROM input1 GROUP BY col1") + + val df0 = sql( + s""" + |INSERT INTO TABLE t PARTITION(p='a') + |SELECT /*+ broadcast(input2) */ input1.col1, input2.col1 + |FROM input1 + |JOIN input2 + |ON input1.col1 = input2.col1 + |""".stripMargin) + checkShuffleAndSort(df0.queryExecution.analyzed, 1, 1) + + val df1 = sql( + s""" + |INSERT INTO TABLE t PARTITION(p='a') + |SELECT /*+ broadcast(input2) */ input1.col1, input1.col2 + |FROM input1 + |LEFT JOIN input2 + |ON input1.col1 = input2.col1 and input1.col2 = input2.col2 + |""".stripMargin) + checkShuffleAndSort(df1.queryExecution.analyzed, 1, 2) + + val df2 = sql( + s""" + 
|INSERT INTO TABLE t PARTITION(p='a') + |SELECT col1 as c1, count(*) as c2 + |FROM input1 + |GROUP BY col1 + |HAVING count(*) > 0 + |""".stripMargin) + checkShuffleAndSort(df2.queryExecution.analyzed, 1, 1) + + // dynamic partition + val df3 = sql( + s""" + |INSERT INTO TABLE t PARTITION(p) + |SELECT /*+ broadcast(input2) */ input1.col1, input1.col2, input1.col2 + |FROM input1 + |JOIN input2 + |ON input1.col1 = input2.col1 + |""".stripMargin) + checkShuffleAndSort(df3.queryExecution.analyzed, 0, 1) + + // non-deterministic + val df4 = sql( + s""" + |INSERT INTO TABLE t PARTITION(p='a') + |SELECT col1 + rand(), count(*) as c2 + |FROM input1 + |GROUP BY col1 + |""".stripMargin) + checkShuffleAndSort(df4.queryExecution.analyzed, 0, 0) + + // view + val df5 = sql( + s""" + |INSERT INTO TABLE t PARTITION(p='a') + |SELECT * FROM v + |""".stripMargin) + checkShuffleAndSort(df5.queryExecution.analyzed, 1, 1) + } + } + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala new file mode 100644 index 00000000000..957089340ca --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +class WatchDogSuite extends WatchDogSuiteBase {} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala new file mode 100644 index 00000000000..a202e813c5e --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala @@ -0,0 +1,601 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
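To summarize what the suite above checks, the rule has roughly the effect of adding a REBALANCE hint to the writing query by hand. A hypothetical before/after (not from the patch, table names are placeholders, and `spark` is a session with the extension installed):

// Without the extension, shuffle-free inserts tend to produce many small files:
spark.sql("INSERT INTO TABLE tmp1 PARTITION(c2) SELECT c1, c2 FROM src")

// With INSERT_REPARTITION_BEFORE_WRITE enabled, the optimizer inserts the
// equivalent of an explicit rebalance on the dynamic partition column:
spark.sql("INSERT INTO TABLE tmp1 PARTITION(c2) SELECT /*+ REBALANCE(c2) */ c1, c2 FROM src")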
+ */ + +package org.apache.spark.sql + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.FileUtils +import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, LogicalPlan} + +import org.apache.kyuubi.sql.KyuubiSQLConf +import org.apache.kyuubi.sql.watchdog.{MaxFileSizeExceedException, MaxPartitionExceedException} + +trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { + override protected def beforeAll(): Unit = { + super.beforeAll() + setupData() + } + + case class LimitAndExpected(limit: Int, expected: Int) + + val limitAndExpecteds = List(LimitAndExpected(1, 1), LimitAndExpected(11, 10)) + + private def checkMaxPartition: Unit = { + withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "100") { + checkAnswer(sql("SELECT count(distinct(p)) FROM test"), Row(10) :: Nil) + } + withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "5") { + sql("SELECT * FROM test where p=1").queryExecution.sparkPlan + + sql(s"SELECT * FROM test WHERE p in (${Range(0, 5).toList.mkString(",")})") + .queryExecution.sparkPlan + + intercept[MaxPartitionExceedException]( + sql("SELECT * FROM test where p != 1").queryExecution.sparkPlan) + + intercept[MaxPartitionExceedException]( + sql("SELECT * FROM test").queryExecution.sparkPlan) + + intercept[MaxPartitionExceedException](sql( + s"SELECT * FROM test WHERE p in (${Range(0, 6).toList.mkString(",")})") + .queryExecution.sparkPlan) + } + } + + test("watchdog with scan maxPartitions -- hive") { + Seq("textfile", "parquet").foreach { format => + withTable("test", "temp") { + sql( + s""" + |CREATE TABLE test(i int) + |PARTITIONED BY (p int) + |STORED AS $format""".stripMargin) + spark.range(0, 10, 1).selectExpr("id as col") + .createOrReplaceTempView("temp") + + for (part <- Range(0, 10)) { + sql( + s""" + |INSERT OVERWRITE TABLE test PARTITION (p='$part') + |select col from temp""".stripMargin) + } + checkMaxPartition + } + } + } + + test("watchdog with scan maxPartitions -- data source") { + withTempDir { dir => + withTempView("test") { + spark.range(10).selectExpr("id", "id as p") + .write + .partitionBy("p") + .mode("overwrite") + .save(dir.getCanonicalPath) + spark.read.load(dir.getCanonicalPath).createOrReplaceTempView("test") + checkMaxPartition + } + } + } + + test("test watchdog: simple SELECT STATEMENT") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + List("", "ORDER BY c1", "ORDER BY c2").foreach { sort => + List("", " DISTINCT").foreach { distinct => + assert(sql( + s""" + |SELECT $distinct * + |FROM t1 + |$sort + |""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + List("", "ORDER BY c1", "ORDER BY c2").foreach { sort => + List("", "DISTINCT").foreach { distinct => + assert(sql( + s""" + |SELECT $distinct * + |FROM t1 + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected)) + } + } + } + } + } + + test("test watchdog: SELECT ... 
WITH AGGREGATE STATEMENT ") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + assert(!sql("SELECT count(*) FROM t1") + .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) + + val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") + val havingConditions = List("", "HAVING cnt > 1") + + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |SELECT c1, COUNT(*) as cnt + |FROM t1 + |GROUP BY c1 + |$having + |$sort + |""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |SELECT c1, COUNT(*) as cnt + |FROM t1 + |GROUP BY c1 + |$having + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected)) + } + } + } + } + } + + test("test watchdog: SELECT with CTE forceMaxOutputRows") { + // simple CTE + val q1 = + """ + |WITH t2 AS ( + | SELECT * FROM t1 + |) + |""".stripMargin + + // nested CTE + val q2 = + """ + |WITH + | t AS (SELECT * FROM t1), + | t2 AS ( + | WITH t3 AS (SELECT * FROM t1) + | SELECT * FROM t3 + | ) + |""".stripMargin + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + val sorts = List("", "ORDER BY c1", "ORDER BY c2") + + sorts.foreach { sort => + Seq(q1, q2).foreach { withQuery => + assert(sql( + s""" + |$withQuery + |SELECT * FROM t2 + |$sort + |""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + sorts.foreach { sort => + Seq(q1, q2).foreach { withQuery => + assert(sql( + s""" + |$withQuery + |SELECT * FROM t2 + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected)) + } + } + } + } + } + + test("test watchdog: SELECT AGGREGATE WITH CTE forceMaxOutputRows") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + assert(!sql( + """ + |WITH custom_cte AS ( + |SELECT * FROM t1 + |) + | + |SELECT COUNT(*) + |FROM custom_cte + |""".stripMargin).queryExecution + .analyzed.isInstanceOf[GlobalLimit]) + + val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") + val havingConditions = List("", "HAVING cnt > 1") + + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |WITH custom_cte AS ( + |SELECT * FROM t1 + |) + | + |SELECT c1, COUNT(*) as cnt + |FROM custom_cte + |GROUP BY c1 + |$having + |$sort + |""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |WITH custom_cte AS ( + |SELECT * FROM t1 + |) + | + |SELECT c1, COUNT(*) as cnt + |FROM custom_cte + |GROUP BY c1 + |$having + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected)) + } + } + } + } + } + + test("test watchdog: UNION Statement for forceMaxOutputRows") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + List("", "ALL").foreach { x => + assert(sql( + s""" + |SELECT c1, c2 FROM t1 + |UNION $x + |SELECT c1, c2 FROM t2 + |UNION $x + |SELECT c1, c2 FROM t3 + |""".stripMargin) + .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) + } + + val sorts = List("", "ORDER BY cnt", 
"ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") + val havingConditions = List("", "HAVING cnt > 1") + + List("", "ALL").foreach { x => + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |SELECT c1, count(c2) as cnt + |FROM t1 + |GROUP BY c1 + |$having + |UNION $x + |SELECT c1, COUNT(c2) as cnt + |FROM t2 + |GROUP BY c1 + |$having + |UNION $x + |SELECT c1, COUNT(c2) as cnt + |FROM t3 + |GROUP BY c1 + |$having + |$sort + |""".stripMargin) + .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) + } + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + assert(sql( + s""" + |SELECT c1, c2 FROM t1 + |UNION + |SELECT c1, c2 FROM t2 + |UNION + |SELECT c1, c2 FROM t3 + |LIMIT $limit + |""".stripMargin) + .queryExecution.optimizedPlan.maxRows.contains(expected)) + } + } + } + + test("test watchdog: Select View Statement for forceMaxOutputRows") { + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "3") { + withTable("tmp_table", "tmp_union") { + withView("tmp_view", "tmp_view2") { + sql(s"create table tmp_table (a int, b int)") + sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)") + sql(s"create table tmp_union (a int, b int)") + sql(s"insert into tmp_union values (6,60),(7,70),(8,80),(9,90),(10,100)") + sql(s"create view tmp_view2 as select * from tmp_union") + assert(!sql( + s""" + |CREATE VIEW tmp_view + |as + |SELECT * FROM + |tmp_table + |""".stripMargin) + .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) + + assert(sql( + s""" + |SELECT * FROM + |tmp_view + |""".stripMargin) + .queryExecution.optimizedPlan.maxRows.contains(3)) + + assert(sql( + s""" + |SELECT * FROM + |tmp_view + |limit 11 + |""".stripMargin) + .queryExecution.optimizedPlan.maxRows.contains(3)) + + assert(sql( + s""" + |SELECT * FROM + |(select * from tmp_view + |UNION + |select * from tmp_view2) + |ORDER BY a + |DESC + |""".stripMargin) + .collect().head.get(0) === 10) + } + } + } + } + + test("test watchdog: Insert Statement for forceMaxOutputRows") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + withTable("tmp_table", "tmp_insert") { + spark.sql(s"create table tmp_table (a int, b int)") + spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)") + val multiInsertTableName1: String = "tmp_tbl1" + val multiInsertTableName2: String = "tmp_tbl2" + sql(s"drop table if exists $multiInsertTableName1") + sql(s"drop table if exists $multiInsertTableName2") + sql(s"create table $multiInsertTableName1 like tmp_table") + sql(s"create table $multiInsertTableName2 like tmp_table") + assert(!sql( + s""" + |FROM tmp_table + |insert into $multiInsertTableName1 select * limit 2 + |insert into $multiInsertTableName2 select * + |""".stripMargin) + .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) + } + } + } + + test("test watchdog: Distribute by for forceMaxOutputRows") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + withTable("tmp_table") { + spark.sql(s"create table tmp_table (a int, b int)") + spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)") + assert(sql( + s""" + |SELECT * + |FROM tmp_table + |DISTRIBUTE BY a + |""".stripMargin) + .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) + } + } + } + + test("test watchdog: Subquery for forceMaxOutputRows") { + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "1") { + withTable("tmp_table1") { + sql("CREATE TABLE 
spark_catalog.`default`.tmp_table1(KEY INT, VALUE STRING) USING PARQUET") + sql("INSERT INTO TABLE spark_catalog.`default`.tmp_table1 " + + "VALUES (1, 'aa'),(2,'bb'),(3, 'cc'),(4,'aa'),(5,'cc'),(6, 'aa')") + assert( + sql("select * from tmp_table1").queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) + val testSqlText = + """ + |select count(*) + |from tmp_table1 + |where tmp_table1.key in ( + |select distinct tmp_table1.key + |from tmp_table1 + |where tmp_table1.value = "aa" + |) + |""".stripMargin + val plan = sql(testSqlText).queryExecution.optimizedPlan + assert(!findGlobalLimit(plan)) + checkAnswer(sql(testSqlText), Row(3) :: Nil) + } + + def findGlobalLimit(plan: LogicalPlan): Boolean = plan match { + case _: GlobalLimit => true + case p if p.children.isEmpty => false + case p => p.children.exists(findGlobalLimit) + } + + } + } + + test("test watchdog: Join for forceMaxOutputRows") { + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "1") { + withTable("tmp_table1", "tmp_table2") { + sql("CREATE TABLE spark_catalog.`default`.tmp_table1(KEY INT, VALUE STRING) USING PARQUET") + sql("INSERT INTO TABLE spark_catalog.`default`.tmp_table1 " + + "VALUES (1, 'aa'),(2,'bb'),(3, 'cc'),(4,'aa'),(5,'cc'),(6, 'aa')") + sql("CREATE TABLE spark_catalog.`default`.tmp_table2(KEY INT, VALUE STRING) USING PARQUET") + sql("INSERT INTO TABLE spark_catalog.`default`.tmp_table2 " + + "VALUES (1, 'aa'),(2,'bb'),(3, 'cc'),(4,'aa'),(5,'cc'),(6, 'aa')") + val testSqlText = + """ + |select a.*,b.* + |from tmp_table1 a + |join + |tmp_table2 b + |on a.KEY = b.KEY + |""".stripMargin + val plan = sql(testSqlText).queryExecution.optimizedPlan + assert(findGlobalLimit(plan)) + } + + def findGlobalLimit(plan: LogicalPlan): Boolean = plan match { + case _: GlobalLimit => true + case p if p.children.isEmpty => false + case p => p.children.exists(findGlobalLimit) + } + } + } + + private def checkMaxFileSize(tableSize: Long, nonPartTableSize: Long): Unit = { + withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_FILE_SIZE.key -> tableSize.toString) { + checkAnswer(sql("SELECT count(distinct(p)) FROM test"), Row(10) :: Nil) + } + + withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_FILE_SIZE.key -> (tableSize / 2).toString) { + sql("SELECT * FROM test where p=1").queryExecution.sparkPlan + + sql(s"SELECT * FROM test WHERE p in (${Range(0, 3).toList.mkString(",")})") + .queryExecution.sparkPlan + + intercept[MaxFileSizeExceedException]( + sql("SELECT * FROM test where p != 1").queryExecution.sparkPlan) + + intercept[MaxFileSizeExceedException]( + sql("SELECT * FROM test").queryExecution.sparkPlan) + + intercept[MaxFileSizeExceedException](sql( + s"SELECT * FROM test WHERE p in (${Range(0, 6).toList.mkString(",")})") + .queryExecution.sparkPlan) + } + + withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_FILE_SIZE.key -> nonPartTableSize.toString) { + checkAnswer(sql("SELECT count(*) FROM test_non_part"), Row(10000) :: Nil) + } + + withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_FILE_SIZE.key -> (nonPartTableSize - 1).toString) { + intercept[MaxFileSizeExceedException]( + sql("SELECT * FROM test_non_part").queryExecution.sparkPlan) + } + } + + test("watchdog with scan maxFileSize -- hive") { + Seq(false).foreach { convertMetastoreParquet => + withTable("test", "test_non_part", "temp") { + spark.range(10000).selectExpr("id as col") + .createOrReplaceTempView("temp") + + // partitioned table + sql( + s""" + |CREATE TABLE test(i int) + |PARTITIONED BY (p int) + |STORED AS parquet""".stripMargin) + for (part <- Range(0, 10)) { + sql( + s""" + |INSERT 
OVERWRITE TABLE test PARTITION (p='$part') + |select col from temp""".stripMargin) + } + + val tablePath = new File(spark.sessionState.catalog.externalCatalog + .getTable("default", "test").location) + val tableSize = FileUtils.listFiles(tablePath, Array("parquet"), true).asScala + .map(_.length()).sum + assert(tableSize > 0) + + // non-partitioned table + sql( + s""" + |CREATE TABLE test_non_part(i int) + |STORED AS parquet""".stripMargin) + sql( + s""" + |INSERT OVERWRITE TABLE test_non_part + |select col from temp""".stripMargin) + sql("ANALYZE TABLE test_non_part COMPUTE STATISTICS") + + val nonPartTablePath = new File(spark.sessionState.catalog.externalCatalog + .getTable("default", "test_non_part").location) + val nonPartTableSize = FileUtils.listFiles(nonPartTablePath, Array("parquet"), true).asScala + .map(_.length()).sum + assert(nonPartTableSize > 0) + + // check + withSQLConf("spark.sql.hive.convertMetastoreParquet" -> convertMetastoreParquet.toString) { + checkMaxFileSize(tableSize, nonPartTableSize) + } + } + } + } + + test("watchdog with scan maxFileSize -- data source") { + withTempDir { dir => + withTempView("test", "test_non_part") { + // partitioned table + val tablePath = new File(dir, "test") + spark.range(10).selectExpr("id", "id as p") + .write + .partitionBy("p") + .mode("overwrite") + .parquet(tablePath.getCanonicalPath) + spark.read.load(tablePath.getCanonicalPath).createOrReplaceTempView("test") + + val tableSize = FileUtils.listFiles(tablePath, Array("parquet"), true).asScala + .map(_.length()).sum + assert(tableSize > 0) + + // non-partitioned table + val nonPartTablePath = new File(dir, "test_non_part") + spark.range(10000).selectExpr("id", "id as p") + .write + .mode("overwrite") + .parquet(nonPartTablePath.getCanonicalPath) + spark.read.load(nonPartTablePath.getCanonicalPath).createOrReplaceTempView("test_non_part") + + val nonPartTableSize = FileUtils.listFiles(nonPartTablePath, Array("parquet"), true).asScala + .map(_.length()).sum + assert(tableSize > 0) + + // check + checkMaxFileSize(tableSize, nonPartTableSize) + } + } + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderCoreBenchmark.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderCoreBenchmark.scala new file mode 100644 index 00000000000..9b1614fce31 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderCoreBenchmark.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
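Pulling the knobs exercised above into one place, a hypothetical guarded session could look like the sketch below. The config keys and exception types are the ones used in the suite; the table name and thresholds are placeholders, and an active session `spark` with the extension installed is assumed.

import org.apache.kyuubi.sql.KyuubiSQLConf
import org.apache.kyuubi.sql.watchdog.{MaxFileSizeExceedException, MaxPartitionExceedException}

// Reject scans that touch too many partitions or too many bytes, and cap the
// rows an ad-hoc query may return.
spark.conf.set(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key, "100")
spark.conf.set(KyuubiSQLConf.WATCHDOG_MAX_FILE_SIZE.key, (10L * 1024 * 1024 * 1024).toString)
spark.conf.set(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key, "1000")

try {
  // Planning (not execution) is where the guard trips, as in the suite above.
  spark.sql("SELECT * FROM big_partitioned_table").queryExecution.sparkPlan
} catch {
  case _: MaxPartitionExceedException | _: MaxFileSizeExceedException =>
    println("query rejected by the watchdog")
}
// Plain SELECTs are additionally capped by an injected GlobalLimit of 1000 rows.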
+ */ + +package org.apache.spark.sql + +import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.benchmark.KyuubiBenchmarkBase +import org.apache.spark.sql.internal.StaticSQLConf + +import org.apache.kyuubi.sql.zorder.ZorderBytesUtils + +/** + * Benchmark to measure performance with zorder core. + * + * {{{ + * RUN_BENCHMARK=1 ./build/mvn clean test \ + * -pl extensions/spark/kyuubi-extension-spark-3-1 -am \ + * -Pspark-3.1,kyuubi-extension-spark-3-1 \ + * -Dtest=none -DwildcardSuites=org.apache.spark.sql.ZorderCoreBenchmark + * }}} + */ +class ZorderCoreBenchmark extends KyuubiSparkSQLExtensionTest with KyuubiBenchmarkBase { + private val runBenchmark = sys.env.contains("RUN_BENCHMARK") + private val numRows = 1 * 1000 * 1000 + + private def randomInt(numColumns: Int): Seq[Array[Any]] = { + (1 to numRows).map { l => + val arr = new Array[Any](numColumns) + (0 until numColumns).foreach(col => arr(col) = l) + arr + } + } + + private def randomLong(numColumns: Int): Seq[Array[Any]] = { + (1 to numRows).map { l => + val arr = new Array[Any](numColumns) + (0 until numColumns).foreach(col => arr(col) = l.toLong) + arr + } + } + + private def interleaveMultiByteArrayBenchmark(): Unit = { + val benchmark = + new Benchmark(s"$numRows rows zorder core benchmark", numRows, output = output) + benchmark.addCase("2 int columns benchmark", 3) { _ => + randomInt(2).foreach(ZorderBytesUtils.interleaveBits) + } + + benchmark.addCase("3 int columns benchmark", 3) { _ => + randomInt(3).foreach(ZorderBytesUtils.interleaveBits) + } + + benchmark.addCase("4 int columns benchmark", 3) { _ => + randomInt(4).foreach(ZorderBytesUtils.interleaveBits) + } + + benchmark.addCase("2 long columns benchmark", 3) { _ => + randomLong(2).foreach(ZorderBytesUtils.interleaveBits) + } + + benchmark.addCase("3 long columns benchmark", 3) { _ => + randomLong(3).foreach(ZorderBytesUtils.interleaveBits) + } + + benchmark.addCase("4 long columns benchmark", 3) { _ => + randomLong(4).foreach(ZorderBytesUtils.interleaveBits) + } + + benchmark.run() + } + + private def paddingTo8ByteBenchmark() { + val iterations = 10 * 1000 * 1000 + + val b2 = Array('a'.toByte, 'b'.toByte) + val benchmark = + new Benchmark(s"$iterations iterations paddingTo8Byte benchmark", iterations, output = output) + benchmark.addCase("2 length benchmark", 3) { _ => + (1 to iterations).foreach(_ => ZorderBytesUtils.paddingTo8Byte(b2)) + } + + val b16 = Array.tabulate(16) { i => i.toByte } + benchmark.addCase("16 length benchmark", 3) { _ => + (1 to iterations).foreach(_ => ZorderBytesUtils.paddingTo8Byte(b16)) + } + + benchmark.run() + } + + test("zorder core benchmark") { + assume(runBenchmark) + + withHeader { + interleaveMultiByteArrayBenchmark() + paddingTo8ByteBenchmark() + } + } + + override def sparkConf(): SparkConf = { + super.sparkConf().remove(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key) + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuite.scala new file mode 100644 index 00000000000..c2fa1619707 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
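For readers new to z-ordering, the benchmark above measures `ZorderBytesUtils.interleaveBits`, which weaves the bits of several column values into a single sort key. A toy two-column, four-bit version of the idea follows; it illustrates the concept only and is not the library's actual implementation.

// Toy 2-column, 4-bit bit interleaving: rows whose (x, y) values are close in
// 2-D space end up close in the resulting 1-D sort order.
def interleave4(x: Int, y: Int): Int =
  (0 until 4).foldLeft(0) { (acc, i) =>
    acc | (((x >> i) & 1) << (2 * i + 1)) | (((y >> i) & 1) << (2 * i))
  }

// x = 3 (0011) and y = 5 (0101) interleave to 00 01 10 11 in binary, i.e. 27.
assert(interleave4(3, 5) == 27)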
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.{RebalancePartitions, Sort} +import org.apache.spark.sql.internal.SQLConf + +import org.apache.kyuubi.sql.{KyuubiSQLConf, SparkKyuubiSparkSQLParser} +import org.apache.kyuubi.sql.zorder.Zorder + +trait ZorderSuiteSpark extends ZorderSuiteBase { + + test("Add rebalance before zorder") { + Seq("true" -> false, "false" -> true).foreach { case (useOriginalOrdering, zorder) => + withSQLConf( + KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED.key -> "false", + KyuubiSQLConf.REBALANCE_BEFORE_ZORDER.key -> "true", + KyuubiSQLConf.REBALANCE_ZORDER_COLUMNS_ENABLED.key -> "true", + KyuubiSQLConf.ZORDER_USING_ORIGINAL_ORDERING_ENABLED.key -> useOriginalOrdering) { + withTable("t") { + sql( + """ + |CREATE TABLE t (c1 int, c2 string) PARTITIONED BY (d string) + | TBLPROPERTIES ( + |'kyuubi.zorder.enabled'= 'true', + |'kyuubi.zorder.cols'= 'c1,C2') + |""".stripMargin) + val p = sql("INSERT INTO TABLE t PARTITION(d='a') SELECT * FROM VALUES(1,'a')") + .queryExecution.analyzed + assert(p.collect { + case sort: Sort + if !sort.global && + ((sort.order.exists(_.child.isInstanceOf[Zorder]) && zorder) || + (!sort.order.exists(_.child.isInstanceOf[Zorder]) && !zorder)) => sort + }.size == 1) + assert(p.collect { + case rebalance: RebalancePartitions + if rebalance.references.map(_.name).exists(_.equals("c1")) => rebalance + }.size == 1) + + val p2 = sql("INSERT INTO TABLE t PARTITION(d) SELECT * FROM VALUES(1,'a','b')") + .queryExecution.analyzed + assert(p2.collect { + case sort: Sort + if (!sort.global && Seq("c1", "c2", "d").forall(x => + sort.references.map(_.name).exists(_.equals(x)))) && + ((sort.order.exists(_.child.isInstanceOf[Zorder]) && zorder) || + (!sort.order.exists(_.child.isInstanceOf[Zorder]) && !zorder)) => sort + }.size == 1) + assert(p2.collect { + case rebalance: RebalancePartitions + if Seq("c1", "c2", "d").forall(x => + rebalance.references.map(_.name).exists(_.equals(x))) => rebalance + }.size == 1) + } + } + } + } + + test("Two phase rebalance before Z-Order") { + withSQLConf( + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.CollapseRepartition", + KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED.key -> "false", + KyuubiSQLConf.REBALANCE_BEFORE_ZORDER.key -> "true", + KyuubiSQLConf.TWO_PHASE_REBALANCE_BEFORE_ZORDER.key -> "true", + KyuubiSQLConf.REBALANCE_ZORDER_COLUMNS_ENABLED.key -> "true") { + withTable("t") { + sql( + """ + |CREATE TABLE t (c1 int) PARTITIONED BY (d string) + | TBLPROPERTIES ( + |'kyuubi.zorder.enabled'= 'true', + |'kyuubi.zorder.cols'= 'c1') + |""".stripMargin) + val p = sql("INSERT INTO TABLE t PARTITION(d) SELECT * FROM VALUES(1,'a')") + val rebalance = p.queryExecution.optimizedPlan.innerChildren + .flatMap(_.collect { case r: RebalancePartitions => r }) + assert(rebalance.size == 2) + 
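// The two-phase rebalance is expected to emit two RebalancePartitions nodes: the first one + // collected shuffles by both the dynamic partition column `d` and the z-order column `c1`, + // while the second shuffles by `d` only; the assertions below verify exactly that. +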
assert(rebalance.head.partitionExpressions.flatMap(_.references.map(_.name)) + .contains("d")) + assert(rebalance.head.partitionExpressions.flatMap(_.references.map(_.name)) + .contains("c1")) + + assert(rebalance(1).partitionExpressions.flatMap(_.references.map(_.name)) + .contains("d")) + assert(!rebalance(1).partitionExpressions.flatMap(_.references.map(_.name)) + .contains("c1")) + } + } + } +} + +trait ParserSuite { self: ZorderSuiteBase => + override def createParser: ParserInterface = { + new SparkKyuubiSparkSQLParser(spark.sessionState.sqlParser) + } +} + +class ZorderWithCodegenEnabledSuite + extends ZorderWithCodegenEnabledSuiteBase + with ZorderSuiteSpark + with ParserSuite {} +class ZorderWithCodegenDisabledSuite + extends ZorderWithCodegenDisabledSuiteBase + with ZorderSuiteSpark + with ParserSuite {} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuiteBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuiteBase.scala new file mode 100644 index 00000000000..2d3eec95722 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuiteBase.scala @@ -0,0 +1,768 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, EqualTo, Expression, ExpressionEvalHelper, Literal, NullsLast, SortOrder} +import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Project, Sort} +import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.execution.InsertIntoHiveTable +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.types._ + +import org.apache.kyuubi.sql.{KyuubiSQLConf, KyuubiSQLExtensionException} +import org.apache.kyuubi.sql.zorder.{OptimizeZorderCommandBase, OptimizeZorderStatement, Zorder, ZorderBytesUtils} + +trait ZorderSuiteBase extends KyuubiSparkSQLExtensionTest with ExpressionEvalHelper { + override def sparkConf(): SparkConf = { + super.sparkConf() + .set( + StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, + "org.apache.kyuubi.sql.KyuubiSparkSQLCommonExtension") + } + + test("optimize unpartitioned table") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + withTable("up") { + sql(s"DROP TABLE IF EXISTS up") + + val target = Seq( + Seq(0, 0), + Seq(1, 0), + Seq(0, 1), + Seq(1, 1), + Seq(2, 0), + Seq(3, 0), + Seq(2, 1), + Seq(3, 1), + Seq(0, 2), + Seq(1, 2), + Seq(0, 3), + Seq(1, 3), + Seq(2, 2), + Seq(3, 2), + Seq(2, 3), + Seq(3, 3)) + sql(s"CREATE TABLE up (c1 INT, c2 INT, c3 INT)") + sql(s"INSERT INTO TABLE up VALUES" + + "(0,0,2),(0,1,2),(0,2,1),(0,3,3)," + + "(1,0,4),(1,1,2),(1,2,1),(1,3,3)," + + "(2,0,2),(2,1,1),(2,2,5),(2,3,5)," + + "(3,0,3),(3,1,4),(3,2,9),(3,3,0)") + + val e = intercept[KyuubiSQLExtensionException] { + sql("OPTIMIZE up WHERE c1 > 1 ZORDER BY c1, c2") + } + assert(e.getMessage == "Filters are only supported for partitioned table") + + sql("OPTIMIZE up ZORDER BY c1, c2") + val res = sql("SELECT c1, c2 FROM up").collect() + + assert(res.length == 16) + + for (i <- target.indices) { + val t = target(i) + val r = res(i) + assert(t(0) == r.getInt(0)) + assert(t(1) == r.getInt(1)) + } + } + } + } + + test("optimize partitioned table") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + withTable("p") { + sql("DROP TABLE IF EXISTS p") + + val target = Seq( + Seq(0, 0), + Seq(1, 0), + Seq(0, 1), + Seq(1, 1), + Seq(2, 0), + Seq(3, 0), + Seq(2, 1), + Seq(3, 1), + Seq(0, 2), + Seq(1, 2), + Seq(0, 3), + Seq(1, 3), + Seq(2, 2), + Seq(3, 2), + Seq(2, 3), + Seq(3, 3)) + + sql(s"CREATE TABLE p (c1 INT, c2 INT, c3 INT) PARTITIONED BY (id INT)") + sql(s"ALTER TABLE p ADD PARTITION (id = 1)") + sql(s"ALTER TABLE p ADD PARTITION (id = 2)") + sql(s"INSERT INTO TABLE p PARTITION (id = 1) VALUES" + + "(0,0,2),(0,1,2),(0,2,1),(0,3,3)," + + "(1,0,4),(1,1,2),(1,2,1),(1,3,3)," + + "(2,0,2),(2,1,1),(2,2,5),(2,3,5)," + + "(3,0,3),(3,1,4),(3,2,9),(3,3,0)") + sql(s"INSERT INTO TABLE p PARTITION (id = 2) VALUES" + + "(0,0,2),(0,1,2),(0,2,1),(0,3,3)," + + "(1,0,4),(1,1,2),(1,2,1),(1,3,3)," + + "(2,0,2),(2,1,1),(2,2,5),(2,3,5)," + + "(3,0,3),(3,1,4),(3,2,9),(3,3,0)") + + sql(s"OPTIMIZE p ZORDER BY c1, c2") + + val res1 = sql(s"SELECT c1, c2 FROM p WHERE id = 1").collect() + val res2 = sql(s"SELECT c1, c2 FROM 
p WHERE id = 2").collect() + + assert(res1.length == 16) + assert(res2.length == 16) + + for (i <- target.indices) { + val t = target(i) + val r1 = res1(i) + assert(t(0) == r1.getInt(0)) + assert(t(1) == r1.getInt(1)) + + val r2 = res2(i) + assert(t(0) == r2.getInt(0)) + assert(t(1) == r2.getInt(1)) + } + } + } + } + + test("optimize partitioned table with filters") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + withTable("p") { + sql("DROP TABLE IF EXISTS p") + + val target1 = Seq( + Seq(0, 0), + Seq(1, 0), + Seq(0, 1), + Seq(1, 1), + Seq(2, 0), + Seq(3, 0), + Seq(2, 1), + Seq(3, 1), + Seq(0, 2), + Seq(1, 2), + Seq(0, 3), + Seq(1, 3), + Seq(2, 2), + Seq(3, 2), + Seq(2, 3), + Seq(3, 3)) + val target2 = Seq( + Seq(0, 0), + Seq(0, 1), + Seq(0, 2), + Seq(0, 3), + Seq(1, 0), + Seq(1, 1), + Seq(1, 2), + Seq(1, 3), + Seq(2, 0), + Seq(2, 1), + Seq(2, 2), + Seq(2, 3), + Seq(3, 0), + Seq(3, 1), + Seq(3, 2), + Seq(3, 3)) + sql(s"CREATE TABLE p (c1 INT, c2 INT, c3 INT) PARTITIONED BY (id INT)") + sql(s"ALTER TABLE p ADD PARTITION (id = 1)") + sql(s"ALTER TABLE p ADD PARTITION (id = 2)") + sql(s"INSERT INTO TABLE p PARTITION (id = 1) VALUES" + + "(0,0,2),(0,1,2),(0,2,1),(0,3,3)," + + "(1,0,4),(1,1,2),(1,2,1),(1,3,3)," + + "(2,0,2),(2,1,1),(2,2,5),(2,3,5)," + + "(3,0,3),(3,1,4),(3,2,9),(3,3,0)") + sql(s"INSERT INTO TABLE p PARTITION (id = 2) VALUES" + + "(0,0,2),(0,1,2),(0,2,1),(0,3,3)," + + "(1,0,4),(1,1,2),(1,2,1),(1,3,3)," + + "(2,0,2),(2,1,1),(2,2,5),(2,3,5)," + + "(3,0,3),(3,1,4),(3,2,9),(3,3,0)") + + val e = intercept[KyuubiSQLExtensionException]( + sql(s"OPTIMIZE p WHERE id = 1 AND c1 > 1 ZORDER BY c1, c2")) + assert(e.getMessage == "Only partition column filters are allowed") + + sql(s"OPTIMIZE p WHERE id = 1 ZORDER BY c1, c2") + + val res1 = sql(s"SELECT c1, c2 FROM p WHERE id = 1").collect() + val res2 = sql(s"SELECT c1, c2 FROM p WHERE id = 2").collect() + + assert(res1.length == 16) + assert(res2.length == 16) + + for (i <- target1.indices) { + val t1 = target1(i) + val r1 = res1(i) + assert(t1(0) == r1.getInt(0)) + assert(t1(1) == r1.getInt(1)) + + val t2 = target2(i) + val r2 = res2(i) + assert(t2(0) == r2.getInt(0)) + assert(t2(1) == r2.getInt(1)) + } + } + } + } + + test("optimize zorder with datasource table") { + // TODO remove this if we support datasource table + withTable("t") { + sql("CREATE TABLE t (c1 int, c2 int) USING PARQUET") + val msg = intercept[KyuubiSQLExtensionException] { + sql("OPTIMIZE t ZORDER BY c1, c2") + }.getMessage + assert(msg.contains("only support hive table")) + } + } + + private def checkZorderTable( + enabled: Boolean, + cols: String, + planHasRepartition: Boolean, + resHasSort: Boolean): Unit = { + def checkSort(plan: LogicalPlan): Unit = { + assert(plan.isInstanceOf[Sort] === resHasSort) + plan match { + case sort: Sort => + val colArr = cols.split(",") + val refs = + if (colArr.length == 1) { + sort.order.head + .child.asInstanceOf[AttributeReference] :: Nil + } else { + sort.order.head + .child.asInstanceOf[Zorder].children.map(_.references.head) + } + assert(refs.size === colArr.size) + refs.zip(colArr).foreach { case (ref, col) => + assert(ref.name === col.trim) + } + case _ => + } + } + + val repartition = + if (planHasRepartition) { + "/*+ repartition */" + } else { + "" + } + withSQLConf("spark.sql.shuffle.partitions" -> "1") { + // hive + withSQLConf("spark.sql.hive.convertMetastoreParquet" -> "false") { + withTable("zorder_t1", "zorder_t2_true", "zorder_t2_false") { + sql( + s""" + |CREATE TABLE zorder_t1 (c1 int, c2 string, c3 
long, c4 double) STORED AS PARQUET + |TBLPROPERTIES ( + | 'kyuubi.zorder.enabled' = '$enabled', + | 'kyuubi.zorder.cols' = '$cols') + |""".stripMargin) + val df1 = sql(s""" + |INSERT INTO TABLE zorder_t1 + |SELECT $repartition * FROM VALUES(1,'a',2,4D),(2,'b',3,6D) + |""".stripMargin) + assert(df1.queryExecution.analyzed.isInstanceOf[InsertIntoHiveTable]) + checkSort(df1.queryExecution.analyzed.children.head) + + Seq("true", "false").foreach { optimized => + withSQLConf( + "spark.sql.hive.convertMetastoreCtas" -> optimized, + "spark.sql.hive.convertMetastoreParquet" -> optimized) { + + withListener( + s""" + |CREATE TABLE zorder_t2_$optimized STORED AS PARQUET + |TBLPROPERTIES ( + | 'kyuubi.zorder.enabled' = '$enabled', + | 'kyuubi.zorder.cols' = '$cols') + | + |SELECT $repartition * FROM + |VALUES(1,'a',2,4D),(2,'b',3,6D) AS t(c1 ,c2 , c3, c4) + |""".stripMargin) { write => + if (optimized.toBoolean) { + assert(write.isInstanceOf[InsertIntoHadoopFsRelationCommand]) + } else { + assert(write.isInstanceOf[InsertIntoHiveTable]) + } + checkSort(write.query) + } + } + } + } + } + + // datasource + withTable("zorder_t3", "zorder_t4") { + sql( + s""" + |CREATE TABLE zorder_t3 (c1 int, c2 string, c3 long, c4 double) USING PARQUET + |TBLPROPERTIES ( + | 'kyuubi.zorder.enabled' = '$enabled', + | 'kyuubi.zorder.cols' = '$cols') + |""".stripMargin) + val df1 = sql(s""" + |INSERT INTO TABLE zorder_t3 + |SELECT $repartition * FROM VALUES(1,'a',2,4D),(2,'b',3,6D) + |""".stripMargin) + assert(df1.queryExecution.analyzed.isInstanceOf[InsertIntoHadoopFsRelationCommand]) + checkSort(df1.queryExecution.analyzed.children.head) + + withListener( + s""" + |CREATE TABLE zorder_t4 USING PARQUET + |TBLPROPERTIES ( + | 'kyuubi.zorder.enabled' = '$enabled', + | 'kyuubi.zorder.cols' = '$cols') + | + |SELECT $repartition * FROM + |VALUES(1,'a',2,4D),(2,'b',3,6D) AS t(c1 ,c2 , c3, c4) + |""".stripMargin) { write => + assert(write.isInstanceOf[InsertIntoHadoopFsRelationCommand]) + checkSort(write.query) + } + } + } + } + + test("Support insert zorder by table properties") { + withSQLConf(KyuubiSQLConf.INSERT_ZORDER_BEFORE_WRITING.key -> "false") { + checkZorderTable(true, "c1", false, false) + checkZorderTable(false, "c1", false, false) + } + withSQLConf(KyuubiSQLConf.INSERT_ZORDER_BEFORE_WRITING.key -> "true") { + checkZorderTable(true, "", false, false) + checkZorderTable(true, "c5", false, false) + checkZorderTable(true, "c1,c5", false, false) + checkZorderTable(false, "c3", false, false) + checkZorderTable(true, "c3", true, false) + checkZorderTable(true, "c3", false, true) + checkZorderTable(true, "c2,c4", false, true) + checkZorderTable(true, "c4, c2, c1, c3", false, true) + } + } + + test("zorder: check unsupported data type") { + def checkZorderPlan(zorder: Expression): Unit = { + val msg = intercept[AnalysisException] { + val plan = Project(Seq(Alias(zorder, "c")()), OneRowRelation()) + spark.sessionState.analyzer.checkAnalysis(plan) + }.getMessage + // before Spark 3.2.0 the null type catalog string is null, after Spark 3.2.0 it's void + // see https://github.com/apache/spark/pull/33437 + assert(msg.contains("Unsupported z-order type:") && + (msg.contains("null") || msg.contains("void"))) + } + + checkZorderPlan(Zorder(Seq(Literal(null, NullType)))) + checkZorderPlan(Zorder(Seq(Literal(1, IntegerType), Literal(null, NullType)))) + } + + test("zorder: check supported data type") { + val children = Seq( + Literal.create(false, BooleanType), + Literal.create(null, BooleanType), + Literal.create(1.toByte, 
ByteType), + Literal.create(null, ByteType), + Literal.create(1.toShort, ShortType), + Literal.create(null, ShortType), + Literal.create(1, IntegerType), + Literal.create(null, IntegerType), + Literal.create(1L, LongType), + Literal.create(null, LongType), + Literal.create(1f, FloatType), + Literal.create(null, FloatType), + Literal.create(1d, DoubleType), + Literal.create(null, DoubleType), + Literal.create("1", StringType), + Literal.create(null, StringType), + Literal.create(1L, TimestampType), + Literal.create(null, TimestampType), + Literal.create(1, DateType), + Literal.create(null, DateType), + Literal.create(BigDecimal(1, 1), DecimalType(1, 1)), + Literal.create(null, DecimalType(1, 1))) + val zorder = Zorder(children) + val plan = Project(Seq(Alias(zorder, "c")()), OneRowRelation()) + spark.sessionState.analyzer.checkAnalysis(plan) + assert(zorder.foldable) + +// // scalastyle:off +// val resultGen = org.apache.commons.codec.binary.Hex.encodeHex( +// zorder.eval(InternalRow.fromSeq(children)).asInstanceOf[Array[Byte]], false) +// resultGen.grouped(2).zipWithIndex.foreach { case (char, i) => +// print("0x" + char(0) + char(1) + ", ") +// if ((i + 1) % 10 == 0) { +// println() +// } +// } +// // scalastyle:on + + val expected = Array( + 0xFB, 0xEA, 0xAA, 0xBA, 0xAE, 0xAB, 0xAA, 0xEA, 0xBA, 0xAE, 0xAB, 0xAA, 0xEA, 0xBA, 0xA6, + 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, + 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xBA, 0xBB, 0xAA, 0xAA, 0xAA, + 0xBA, 0xAA, 0xBA, 0xAA, 0xBA, 0xAA, 0xBA, 0xAA, 0xBA, 0xAA, 0xBA, 0xAA, 0x9A, 0xAA, 0xAA, + 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xEA, + 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, + 0xAA, 0xAA, 0xBE, 0xAA, 0xAA, 0x8A, 0xBA, 0xAA, 0x2A, 0xEA, 0xA8, 0xAA, 0xAA, 0xA2, 0xAA, + 0xAA, 0x8A, 0xAA, 0xAA, 0x2F, 0xEB, 0xFE) + .map(_.toByte) + checkEvaluation(zorder, expected, InternalRow.fromSeq(children)) + } + + private def checkSort(input: DataFrame, expected: Seq[Row], dataType: Array[DataType]): Unit = { + withTempDir { dir => + input.repartition(3).write.mode("overwrite").format("parquet").save(dir.getCanonicalPath) + val df = spark.read.format("parquet") + .load(dir.getCanonicalPath) + .repartition(1) + assert(df.schema.fields.map(_.dataType).sameElements(dataType)) + val exprs = Seq("c1", "c2").map(col).map(_.expr) + val sortOrder = SortOrder(Zorder(exprs), Ascending, NullsLast, Seq.empty) + val zorderSort = Sort(Seq(sortOrder), true, df.logicalPlan) + val result = Dataset.ofRows(spark, zorderSort) + checkAnswer(result, expected) + } + } + + test("sort with zorder -- boolean column") { + val schema = StructType(StructField("c1", BooleanType) :: StructField("c2", BooleanType) :: Nil) + val nonNullDF = spark.createDataFrame( + spark.sparkContext.parallelize( + Seq(Row(false, false), Row(false, true), Row(true, false), Row(true, true))), + schema) + val expected = + Row(false, false) :: Row(true, false) :: Row(false, true) :: Row(true, true) :: Nil + checkSort(nonNullDF, expected, Array(BooleanType, BooleanType)) + val df = spark.createDataFrame( + spark.sparkContext.parallelize( + Seq(Row(false, false), Row(false, null), Row(null, false), Row(null, null))), + schema) + val expected2 = + Row(false, false) :: Row(null, false) :: Row(false, null) :: Row(null, null) :: Nil + checkSort(df, expected2, Array(BooleanType, BooleanType)) + } + + test("sort with zorder -- int column") { + // TODO: add 
more datatype unit test + val session = spark + import session.implicits._ + // generate 4 * 4 matrix + val len = 3 + val input = spark.range(len + 1).selectExpr("cast(id as int) as c1") + .select($"c1", explode(sequence(lit(0), lit(len))) as "c2") + val expected = + Row(0, 0) :: Row(1, 0) :: Row(0, 1) :: Row(1, 1) :: + Row(2, 0) :: Row(3, 0) :: Row(2, 1) :: Row(3, 1) :: + Row(0, 2) :: Row(1, 2) :: Row(0, 3) :: Row(1, 3) :: + Row(2, 2) :: Row(3, 2) :: Row(2, 3) :: Row(3, 3) :: Nil + checkSort(input, expected, Array(IntegerType, IntegerType)) + + // contains null value case. + val nullDF = spark.range(1).selectExpr("cast(null as int) as c1") + val input2 = spark.range(len).selectExpr("cast(id as int) as c1") + .union(nullDF) + .select( + $"c1", + explode(concat(sequence(lit(0), lit(len - 1)), array(lit(null)))) as "c2") + val expected2 = Row(0, 0) :: Row(1, 0) :: Row(0, 1) :: Row(1, 1) :: + Row(2, 0) :: Row(2, 1) :: Row(0, 2) :: Row(1, 2) :: + Row(2, 2) :: Row(null, 0) :: Row(null, 1) :: Row(null, 2) :: + Row(0, null) :: Row(1, null) :: Row(2, null) :: Row(null, null) :: Nil + checkSort(input2, expected2, Array(IntegerType, IntegerType)) + } + + test("sort with zorder -- string column") { + val schema = StructType(StructField("c1", StringType) :: StructField("c2", StringType) :: Nil) + val rdd = spark.sparkContext.parallelize(Seq( + Row("a", "a"), + Row("a", "b"), + Row("a", "c"), + Row("a", "d"), + Row("b", "a"), + Row("b", "b"), + Row("b", "c"), + Row("b", "d"), + Row("c", "a"), + Row("c", "b"), + Row("c", "c"), + Row("c", "d"), + Row("d", "a"), + Row("d", "b"), + Row("d", "c"), + Row("d", "d"))) + val input = spark.createDataFrame(rdd, schema) + val expected = Row("a", "a") :: Row("b", "a") :: Row("c", "a") :: Row("a", "b") :: + Row("a", "c") :: Row("b", "b") :: Row("c", "b") :: Row("b", "c") :: + Row("c", "c") :: Row("d", "a") :: Row("d", "b") :: Row("d", "c") :: + Row("a", "d") :: Row("b", "d") :: Row("c", "d") :: Row("d", "d") :: Nil + checkSort(input, expected, Array(StringType, StringType)) + + val rdd2 = spark.sparkContext.parallelize(Seq( + Row(null, "a"), + Row("a", "b"), + Row("a", "c"), + Row("a", null), + Row("b", "a"), + Row(null, "b"), + Row("b", null), + Row("b", "d"), + Row("c", "a"), + Row("c", null), + Row(null, "c"), + Row("c", "d"), + Row("d", null), + Row("d", "b"), + Row("d", "c"), + Row(null, "d"), + Row(null, null))) + val input2 = spark.createDataFrame(rdd2, schema) + val expected2 = Row("b", "a") :: Row("c", "a") :: Row("a", "b") :: Row("a", "c") :: + Row("d", "b") :: Row("d", "c") :: Row("b", "d") :: Row("c", "d") :: + Row(null, "a") :: Row(null, "b") :: Row(null, "c") :: Row(null, "d") :: + Row("a", null) :: Row("b", null) :: Row("c", null) :: Row("d", null) :: + Row(null, null) :: Nil + checkSort(input2, expected2, Array(StringType, StringType)) + } + + test("test special value of short int long type") { + val df1 = spark.createDataFrame(Seq( + (-1, -1L), + (Int.MinValue, Int.MinValue.toLong), + (1, 1L), + (Int.MaxValue - 1, Int.MaxValue.toLong), + (Int.MaxValue - 1, Int.MaxValue.toLong - 1), + (Int.MaxValue, Int.MaxValue.toLong + 1), + (Int.MaxValue, Int.MaxValue.toLong))).toDF("c1", "c2") + val expected1 = + Row(Int.MinValue, Int.MinValue.toLong) :: + Row(-1, -1L) :: + Row(1, 1L) :: + Row(Int.MaxValue - 1, Int.MaxValue.toLong - 1) :: + Row(Int.MaxValue - 1, Int.MaxValue.toLong) :: + Row(Int.MaxValue, Int.MaxValue.toLong) :: + Row(Int.MaxValue, Int.MaxValue.toLong + 1) :: Nil + checkSort(df1, expected1, Array(IntegerType, LongType)) + + val df2 = 
spark.createDataFrame(Seq( + (-1, -1.toShort), + (Short.MinValue.toInt, Short.MinValue), + (1, 1.toShort), + (Short.MaxValue.toInt, (Short.MaxValue - 1).toShort), + (Short.MaxValue.toInt + 1, (Short.MaxValue - 1).toShort), + (Short.MaxValue.toInt, Short.MaxValue), + (Short.MaxValue.toInt + 1, Short.MaxValue))).toDF("c1", "c2") + val expected2 = + Row(Short.MinValue.toInt, Short.MinValue) :: + Row(-1, -1.toShort) :: + Row(1, 1.toShort) :: + Row(Short.MaxValue.toInt, Short.MaxValue - 1) :: + Row(Short.MaxValue.toInt, Short.MaxValue) :: + Row(Short.MaxValue.toInt + 1, Short.MaxValue - 1) :: + Row(Short.MaxValue.toInt + 1, Short.MaxValue) :: Nil + checkSort(df2, expected2, Array(IntegerType, ShortType)) + + val df3 = spark.createDataFrame(Seq( + (-1L, -1.toShort), + (Short.MinValue.toLong, Short.MinValue), + (1L, 1.toShort), + (Short.MaxValue.toLong, (Short.MaxValue - 1).toShort), + (Short.MaxValue.toLong + 1, (Short.MaxValue - 1).toShort), + (Short.MaxValue.toLong, Short.MaxValue), + (Short.MaxValue.toLong + 1, Short.MaxValue))).toDF("c1", "c2") + val expected3 = + Row(Short.MinValue.toLong, Short.MinValue) :: + Row(-1L, -1.toShort) :: + Row(1L, 1.toShort) :: + Row(Short.MaxValue.toLong, Short.MaxValue - 1) :: + Row(Short.MaxValue.toLong, Short.MaxValue) :: + Row(Short.MaxValue.toLong + 1, Short.MaxValue - 1) :: + Row(Short.MaxValue.toLong + 1, Short.MaxValue) :: Nil + checkSort(df3, expected3, Array(LongType, ShortType)) + } + + test("skip zorder if only requires one column") { + withTable("t") { + withSQLConf("spark.sql.hive.convertMetastoreParquet" -> "false") { + sql("CREATE TABLE t (c1 int, c2 string) stored as parquet") + val order1 = sql("OPTIMIZE t ZORDER BY c1").queryExecution.analyzed + .asInstanceOf[OptimizeZorderCommandBase].query.asInstanceOf[Sort].order.head.child + assert(!order1.isInstanceOf[Zorder]) + assert(order1.isInstanceOf[AttributeReference]) + } + } + } + + test("Add config to control if zorder using global sort") { + withTable("t") { + withSQLConf(KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED.key -> "false") { + sql( + """ + |CREATE TABLE t (c1 int, c2 string) TBLPROPERTIES ( + |'kyuubi.zorder.enabled'= 'true', + |'kyuubi.zorder.cols'= 'c1,c2') + |""".stripMargin) + val p1 = sql("OPTIMIZE t ZORDER BY c1, c2").queryExecution.analyzed + assert(p1.collect { + case shuffle: Sort if !shuffle.global => shuffle + }.size == 1) + + val p2 = sql("INSERT INTO TABLE t SELECT * FROM VALUES(1,'a')").queryExecution.analyzed + assert(p2.collect { + case shuffle: Sort if !shuffle.global => shuffle + }.size == 1) + } + } + } + + test("fast approach test") { + Seq[Seq[Any]]( + Seq(1L, 2L), + Seq(1L, 2L, 3L), + Seq(1L, 2L, 3L, 4L), + Seq(1L, 2L, 3L, 4L, 5L), + Seq(1L, 2L, 3L, 4L, 5L, 6L), + Seq(1L, 2L, 3L, 4L, 5L, 6L, 7L), + Seq(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L)) + .foreach { inputs => + assert(java.util.Arrays.equals( + ZorderBytesUtils.interleaveBits(inputs.toArray), + ZorderBytesUtils.interleaveBitsDefault(inputs.map(ZorderBytesUtils.toByteArray).toArray))) + } + } + + test("OPTIMIZE command is parsed as expected") { + val parser = createParser + val globalSort = spark.conf.get(KyuubiSQLConf.ZORDER_GLOBAL_SORT_ENABLED) + + assert(parser.parsePlan("OPTIMIZE p zorder by c1") === + OptimizeZorderStatement( + Seq("p"), + Sort( + SortOrder(UnresolvedAttribute("c1"), Ascending, NullsLast, Seq.empty) :: Nil, + globalSort, + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(TableIdentifier("p")))))) + + assert(parser.parsePlan("OPTIMIZE p zorder by c1, c2") === + OptimizeZorderStatement( + 
Seq("p"), + Sort( + SortOrder( + Zorder(Seq(UnresolvedAttribute("c1"), UnresolvedAttribute("c2"))), + Ascending, + NullsLast, + Seq.empty) :: Nil, + globalSort, + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(TableIdentifier("p")))))) + + assert(parser.parsePlan("OPTIMIZE p where id = 1 zorder by c1") === + OptimizeZorderStatement( + Seq("p"), + Sort( + SortOrder(UnresolvedAttribute("c1"), Ascending, NullsLast, Seq.empty) :: Nil, + globalSort, + Project( + Seq(UnresolvedStar(None)), + Filter( + EqualTo(UnresolvedAttribute("id"), Literal(1)), + UnresolvedRelation(TableIdentifier("p"))))))) + + assert(parser.parsePlan("OPTIMIZE p where id = 1 zorder by c1, c2") === + OptimizeZorderStatement( + Seq("p"), + Sort( + SortOrder( + Zorder(Seq(UnresolvedAttribute("c1"), UnresolvedAttribute("c2"))), + Ascending, + NullsLast, + Seq.empty) :: Nil, + globalSort, + Project( + Seq(UnresolvedStar(None)), + Filter( + EqualTo(UnresolvedAttribute("id"), Literal(1)), + UnresolvedRelation(TableIdentifier("p"))))))) + + assert(parser.parsePlan("OPTIMIZE p where id = current_date() zorder by c1") === + OptimizeZorderStatement( + Seq("p"), + Sort( + SortOrder(UnresolvedAttribute("c1"), Ascending, NullsLast, Seq.empty) :: Nil, + globalSort, + Project( + Seq(UnresolvedStar(None)), + Filter( + EqualTo( + UnresolvedAttribute("id"), + UnresolvedFunction("current_date", Seq.empty, false)), + UnresolvedRelation(TableIdentifier("p"))))))) + + // TODO: add following case support + intercept[ParseException] { + parser.parsePlan("OPTIMIZE p zorder by (c1)") + } + + intercept[ParseException] { + parser.parsePlan("OPTIMIZE p zorder by (c1, c2)") + } + } + + test("OPTIMIZE partition predicates constraint") { + withTable("p") { + sql("CREATE TABLE p (c1 INT, c2 INT) PARTITIONED BY (event_date DATE)") + val e1 = intercept[KyuubiSQLExtensionException] { + sql("OPTIMIZE p WHERE event_date = current_date as c ZORDER BY c1, c2") + } + assert(e1.getMessage.contains("unsupported partition predicates")) + + val e2 = intercept[KyuubiSQLExtensionException] { + sql("OPTIMIZE p WHERE c1 = 1 ZORDER BY c1, c2") + } + assert(e2.getMessage == "Only partition column filters are allowed") + } + } + + def createParser: ParserInterface +} + +trait ZorderWithCodegenEnabledSuiteBase extends ZorderSuiteBase { + override def sparkConf(): SparkConf = { + val conf = super.sparkConf + conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + conf + } +} + +trait ZorderWithCodegenDisabledSuiteBase extends ZorderSuiteBase { + override def sparkConf(): SparkConf = { + val conf = super.sparkConf + conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false") + conf.set(SQLConf.CODEGEN_FACTORY_MODE.key, "NO_CODEGEN") + conf + } +} diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/benchmark/KyuubiBenchmarkBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/benchmark/KyuubiBenchmarkBase.scala new file mode 100644 index 00000000000..b891a7224a0 --- /dev/null +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/benchmark/KyuubiBenchmarkBase.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.benchmark + +import java.io.{File, FileOutputStream, OutputStream} + +import scala.collection.JavaConverters._ + +import com.google.common.reflect.ClassPath +import org.scalatest.Assertions._ + +trait KyuubiBenchmarkBase { + var output: Option[OutputStream] = None + + private val prefix = { + val benchmarkClasses = ClassPath.from(Thread.currentThread.getContextClassLoader) + .getTopLevelClassesRecursive("org.apache.spark.sql").asScala.toArray + assert(benchmarkClasses.nonEmpty) + val benchmark = benchmarkClasses.find(_.load().getName.endsWith("Benchmark")) + val targetDirOrProjDir = + new File(benchmark.get.load().getProtectionDomain.getCodeSource.getLocation.toURI) + .getParentFile.getParentFile + if (targetDirOrProjDir.getName == "target") { + targetDirOrProjDir.getParentFile.getCanonicalPath + "/" + } else { + targetDirOrProjDir.getCanonicalPath + "/" + } + } + + def withHeader(func: => Unit): Unit = { + val version = System.getProperty("java.version").split("\\D+")(0).toInt + val jdkString = if (version > 8) s"-jdk$version" else "" + val resultFileName = + s"${this.getClass.getSimpleName.replace("$", "")}$jdkString-results.txt" + val dir = new File(s"${prefix}benchmarks/") + if (!dir.exists()) { + // scalastyle:off println + println(s"Creating ${dir.getAbsolutePath} for benchmark results.") + // scalastyle:on println + dir.mkdirs() + } + val file = new File(dir, resultFileName) + if (!file.exists()) { + file.createNewFile() + } + output = Some(new FileOutputStream(file)) + + func + + output.foreach { o => + if (o != null) { + o.close() + } + } + } +} diff --git a/extensions/spark/kyuubi-spark-authz/src/main/resources/table_command_spec.json b/extensions/spark/kyuubi-spark-authz/src/main/resources/table_command_spec.json index 3e191146862..06d76c7e530 100644 --- a/extensions/spark/kyuubi-spark-authz/src/main/resources/table_command_spec.json +++ b/extensions/spark/kyuubi-spark-authz/src/main/resources/table_command_spec.json @@ -157,7 +157,7 @@ }, { "classname" : "org.apache.spark.sql.catalyst.plans.logical.CreateTableAsSelect", "tableDescs" : [ { - "fieldName" : "left", + "fieldName" : "name", "fieldExtractor" : "ResolvedIdentifierTableExtractor", "columnDesc" : null, "actionTypeDesc" : null, @@ -178,7 +178,7 @@ "isInput" : false, "setCurrentDatabaseIfMissing" : false }, { - "fieldName" : "left", + "fieldName" : "name", "fieldExtractor" : "ResolvedDbObjectNameTableExtractor", "columnDesc" : null, "actionTypeDesc" : null, @@ -508,7 +508,7 @@ }, { "classname" : "org.apache.spark.sql.catalyst.plans.logical.ReplaceTableAsSelect", "tableDescs" : [ { - "fieldName" : "left", + "fieldName" : "name", "fieldExtractor" : "ResolvedIdentifierTableExtractor", "columnDesc" : null, "actionTypeDesc" : null, @@ -529,7 +529,7 @@ "isInput" : false, "setCurrentDatabaseIfMissing" : false }, { - "fieldName" : "left", + "fieldName" : "name", "fieldExtractor" : "ResolvedDbObjectNameTableExtractor", "columnDesc" : null, "actionTypeDesc" : null, diff --git 
a/extensions/spark/kyuubi-spark-authz/src/test/scala/org/apache/kyuubi/plugin/spark/authz/gen/TableCommands.scala b/extensions/spark/kyuubi-spark-authz/src/test/scala/org/apache/kyuubi/plugin/spark/authz/gen/TableCommands.scala index ca2ee92948e..6a6800210dc 100644 --- a/extensions/spark/kyuubi-spark-authz/src/test/scala/org/apache/kyuubi/plugin/spark/authz/gen/TableCommands.scala +++ b/extensions/spark/kyuubi-spark-authz/src/test/scala/org/apache/kyuubi/plugin/spark/authz/gen/TableCommands.scala @@ -234,9 +234,9 @@ object TableCommands { TableCommandSpec( cmd, Seq( - resolvedIdentifierTableDesc.copy(fieldName = "left"), + resolvedIdentifierTableDesc.copy(fieldName = "name"), tableDesc, - resolvedDbObjectNameDesc.copy(fieldName = "left")), + resolvedDbObjectNameDesc.copy(fieldName = "name")), CREATETABLE_AS_SELECT, Seq(queryQueryDesc)) } diff --git a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/Lineage.scala b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/Lineage.scala index 4bd0bd0b168..730deeb01e2 100644 --- a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/Lineage.scala +++ b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/Lineage.scala @@ -32,8 +32,9 @@ class Lineage( override def equals(other: Any): Boolean = other match { case otherLineage: Lineage => - otherLineage.inputTables == inputTables && otherLineage.outputTables == outputTables && - otherLineage.columnLineage == columnLineage + otherLineage.inputTables.toSet == inputTables.toSet && + otherLineage.outputTables.toSet == outputTables.toSet && + otherLineage.columnLineage.toSet == columnLineage.toSet case _ => false } diff --git a/integration-tests/kyuubi-flink-it/pom.xml b/integration-tests/kyuubi-flink-it/pom.xml index 3c0e3f31a7c..15699be1d8a 100644 --- a/integration-tests/kyuubi-flink-it/pom.xml +++ b/integration-tests/kyuubi-flink-it/pom.xml @@ -112,4 +112,8 @@ + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + diff --git a/integration-tests/kyuubi-hive-it/pom.xml b/integration-tests/kyuubi-hive-it/pom.xml index 24e5529a2d3..c4e9f320c95 100644 --- a/integration-tests/kyuubi-hive-it/pom.xml +++ b/integration-tests/kyuubi-hive-it/pom.xml @@ -69,4 +69,9 @@ test + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + diff --git a/integration-tests/kyuubi-jdbc-it/pom.xml b/integration-tests/kyuubi-jdbc-it/pom.xml index 08f74512e90..95ffd2038c1 100644 --- a/integration-tests/kyuubi-jdbc-it/pom.xml +++ b/integration-tests/kyuubi-jdbc-it/pom.xml @@ -114,5 +114,7 @@ + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes diff --git a/integration-tests/kyuubi-trino-it/pom.xml b/integration-tests/kyuubi-trino-it/pom.xml index 628f63818b9..c93d43c005b 100644 --- a/integration-tests/kyuubi-trino-it/pom.xml +++ b/integration-tests/kyuubi-trino-it/pom.xml @@ -88,4 +88,9 @@ + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + diff --git a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/IcebergMetadataTests.scala b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/IcebergMetadataTests.scala index e3bb4ccb730..99482f0c5ff 100644 --- a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/IcebergMetadataTests.scala +++ 
b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/IcebergMetadataTests.scala @@ -17,8 +17,11 @@ package org.apache.kyuubi.operation +import scala.collection.mutable.ListBuffer + import org.apache.kyuubi.{IcebergSuiteMixin, SPARK_COMPILE_VERSION} import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._ +import org.apache.kyuubi.util.AssertionUtils._ import org.apache.kyuubi.util.SparkVersionUtil trait IcebergMetadataTests extends HiveJDBCTestHelper with IcebergSuiteMixin with SparkVersionUtil { @@ -27,10 +30,11 @@ trait IcebergMetadataTests extends HiveJDBCTestHelper with IcebergSuiteMixin wit withJdbcStatement() { statement => val metaData = statement.getConnection.getMetaData val catalogs = metaData.getCatalogs - catalogs.next() - assert(catalogs.getString(TABLE_CAT) === "spark_catalog") - catalogs.next() - assert(catalogs.getString(TABLE_CAT) === catalog) + val results = ListBuffer[String]() + while (catalogs.next()) { + results += catalogs.getString(TABLE_CAT) + } + assertContains(results, "spark_catalog", catalog) } } diff --git a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala index 20d3f6fad5b..0ac56e3bcf0 100644 --- a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala +++ b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala @@ -270,7 +270,7 @@ trait SparkQueryTests extends SparkDataTypeTests with HiveJDBCTestHelper { |""".stripMargin val rs1 = statement.executeQuery(code) rs1.next() - assert(rs1.getString(1) startsWith "df: org.apache.spark.sql.DataFrame") + assert(rs1.getString(1) contains "df: org.apache.spark.sql.DataFrame") // continue val rs2 = statement.executeQuery("df.count()") @@ -311,7 +311,7 @@ trait SparkQueryTests extends SparkDataTypeTests with HiveJDBCTestHelper { |""".stripMargin val rs5 = statement.executeQuery(code2) rs5.next() - assert(rs5.getString(1) startsWith "df: org.apache.spark.sql.DataFrame") + assert(rs5.getString(1) contains "df: org.apache.spark.sql.DataFrame") // re-assign val rs6 = statement.executeQuery("result.set(df)") @@ -420,7 +420,7 @@ trait SparkQueryTests extends SparkDataTypeTests with HiveJDBCTestHelper { statement.execute(code1) val rs = statement.executeQuery(code2) rs.next() - assert(rs.getString(1) == "x: Int = 3") + assert(rs.getString(1) contains "x: Int = 3") } } diff --git a/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/AdminRestApi.java b/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/AdminRestApi.java index 3b220cbc234..904ecb6c9d6 100644 --- a/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/AdminRestApi.java +++ b/kyuubi-rest-client/src/main/java/org/apache/kyuubi/client/AdminRestApi.java @@ -17,10 +17,7 @@ package org.apache.kyuubi.client; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import org.apache.kyuubi.client.api.v1.dto.Engine; import org.apache.kyuubi.client.api.v1.dto.OperationData; import org.apache.kyuubi.client.api.v1.dto.ServerData; @@ -87,9 +84,17 @@ public List listEngines( } public List listSessions() { + return listSessions(Collections.emptyList()); + } + + public List listSessions(List users) { + Map params = new HashMap<>(); + if (users != null && !users.isEmpty()) { + params.put("users", String.join(",", users)); + } SessionData[] result = this.getClient() - .get(API_BASE_PATH + "/sessions", null, SessionData[].class, 
client.getAuthHeader()); + .get(API_BASE_PATH + "/sessions", params, SessionData[].class, client.getAuthHeader()); return Arrays.asList(result); } diff --git a/kyuubi-server/src/main/resources/sql/derby/005-KYUUBI-5327.derby.sql b/kyuubi-server/src/main/resources/sql/derby/005-KYUUBI-5327.derby.sql new file mode 100644 index 00000000000..32c44d0fb64 --- /dev/null +++ b/kyuubi-server/src/main/resources/sql/derby/005-KYUUBI-5327.derby.sql @@ -0,0 +1,3 @@ +ALTER TABLE metadata ADD COLUMN priority int NOT NULL DEFAULT 10; + +CREATE INDEX metadata_priority_create_time_index ON metadata(priority, create_time); diff --git a/kyuubi-server/src/main/resources/sql/derby/metadata-store-schema-1.8.0.derby.sql b/kyuubi-server/src/main/resources/sql/derby/metadata-store-schema-1.8.0.derby.sql index 8d333bda2bd..139f70d3b8b 100644 --- a/kyuubi-server/src/main/resources/sql/derby/metadata-store-schema-1.8.0.derby.sql +++ b/kyuubi-server/src/main/resources/sql/derby/metadata-store-schema-1.8.0.derby.sql @@ -26,6 +26,7 @@ CREATE TABLE metadata( engine_state varchar(32), -- the engine application state engine_error clob, -- the engine application diagnose end_time bigint, -- the metadata end time + priority int NOT NULL DEFAULT 10, -- the application priority, high value means high priority peer_instance_closed boolean default FALSE -- closed by peer kyuubi instance ); @@ -36,3 +37,5 @@ CREATE INDEX metadata_user_name_index ON metadata(user_name); CREATE INDEX metadata_engine_type_index ON metadata(engine_type); CREATE INDEX metadata_create_time_index ON metadata(create_time); + +CREATE INDEX metadata_priority_create_time_index ON metadata(priority, create_time); diff --git a/kyuubi-server/src/main/resources/sql/mysql/005-KYUUBI-5327.mysql.sql b/kyuubi-server/src/main/resources/sql/mysql/005-KYUUBI-5327.mysql.sql new file mode 100644 index 00000000000..0637e053d8d --- /dev/null +++ b/kyuubi-server/src/main/resources/sql/mysql/005-KYUUBI-5327.mysql.sql @@ -0,0 +1,13 @@ +SELECT '< KYUUBI-5327: Introduce priority in metadata' AS ' '; + +ALTER TABLE metadata ADD COLUMN priority int NOT NULL DEFAULT 10 COMMENT 'the application priority, high value means high priority'; + +-- In MySQL 5.7, A key_part specification can end with ASC or DESC. +-- These keywords are permitted for future extensions for specifying ascending or descending index value storage. +-- Currently, they are parsed but ignored; index values are always stored in ascending order. +-- In MySQL 8 this can take effect and this index will be hit if query order by priority DESC, create_time ASC. 
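+-- For illustration only (not part of the migration itself): on MySQL 8 a query such as +--   SELECT identifier FROM metadata ORDER BY priority DESC, create_time ASC LIMIT 10; +-- can be served by priority_create_time_index without an additional sort.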
+-- See more detail in: +-- https://dev.mysql.com/doc/refman/8.0/en/index-hints.html +-- https://dev.mysql.com/doc/refman/8.0/en/create-index.html +-- https://dev.mysql.com/doc/refman/5.7/en/create-index.html +ALTER TABLE metadata ADD INDEX priority_create_time_index(priority DESC, create_time ASC); diff --git a/kyuubi-server/src/main/resources/sql/mysql/metadata-store-schema-1.8.0.mysql.sql b/kyuubi-server/src/main/resources/sql/mysql/metadata-store-schema-1.8.0.mysql.sql index 77df8fa0562..fb2019848d7 100644 --- a/kyuubi-server/src/main/resources/sql/mysql/metadata-store-schema-1.8.0.mysql.sql +++ b/kyuubi-server/src/main/resources/sql/mysql/metadata-store-schema-1.8.0.mysql.sql @@ -24,9 +24,12 @@ CREATE TABLE IF NOT EXISTS metadata( engine_state varchar(32) COMMENT 'the engine application state', engine_error mediumtext COMMENT 'the engine application diagnose', end_time bigint COMMENT 'the metadata end time', + priority int NOT NULL DEFAULT 10 COMMENT 'the application priority, high value means high priority', peer_instance_closed boolean default '0' COMMENT 'closed by peer kyuubi instance', UNIQUE INDEX unique_identifier_index(identifier), INDEX user_name_index(user_name), INDEX engine_type_index(engine_type), - INDEX create_time_index(create_time) + INDEX create_time_index(create_time), + -- See more detail about this index in ./005-KYUUBI-5327.mysql.sql + INDEX priority_create_time_index(priority DESC, create_time ASC) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; diff --git a/kyuubi-server/src/main/resources/sql/mysql/upgrade-1.7.0-to-1.8.0.mysql.sql b/kyuubi-server/src/main/resources/sql/mysql/upgrade-1.7.0-to-1.8.0.mysql.sql index 473997448ba..0dd3abfda55 100644 --- a/kyuubi-server/src/main/resources/sql/mysql/upgrade-1.7.0-to-1.8.0.mysql.sql +++ b/kyuubi-server/src/main/resources/sql/mysql/upgrade-1.7.0-to-1.8.0.mysql.sql @@ -1,4 +1,5 @@ SELECT '< Upgrading MetaStore schema from 1.7.0 to 1.8.0 >' AS ' '; SOURCE 003-KYUUBI-5078.mysql.sql; SOURCE 004-KYUUBI-5131.mysql.sql; +SOURCE 005-KYUUBI-5327.mysql.sql; SELECT '< Finished upgrading MetaStore schema from 1.7.0 to 1.8.0 >' AS ' '; diff --git a/kyuubi-server/src/main/resources/sql/sqlite/metadata-store-schema-1.8.0.sqlite.sql b/kyuubi-server/src/main/resources/sql/sqlite/metadata-store-schema-1.8.0.sqlite.sql index 656de6e5d62..aa50267eba0 100644 --- a/kyuubi-server/src/main/resources/sql/sqlite/metadata-store-schema-1.8.0.sqlite.sql +++ b/kyuubi-server/src/main/resources/sql/sqlite/metadata-store-schema-1.8.0.sqlite.sql @@ -24,6 +24,7 @@ CREATE TABLE IF NOT EXISTS metadata( engine_state varchar(32), -- the engine application state engine_error mediumtext, -- the engine application diagnose end_time bigint, -- the metadata end time + priority INTEGER NOT NULL DEFAULT 10, -- the application priority, high value means high priority peer_instance_closed boolean default '0' -- closed by peer kyuubi instance ); @@ -34,3 +35,5 @@ CREATE INDEX IF NOT EXISTS metadata_user_name_index ON metadata(user_name); CREATE INDEX IF NOT EXISTS metadata_engine_type_index ON metadata(engine_type); CREATE INDEX IF NOT EXISTS metadata_create_time_index ON metadata(create_time); + +CREATE INDEX IF NOT EXISTS metadata_priority_create_time_index ON metadata(priority, create_time); diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/ProcBuilder.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/ProcBuilder.scala index 44b317c71ea..84807a62d87 100644 --- a/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/ProcBuilder.scala +++ 
b/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/ProcBuilder.scala @@ -17,7 +17,7 @@ package org.apache.kyuubi.engine -import java.io.{File, FilenameFilter, IOException} +import java.io.{File, FileFilter, IOException} import java.net.URI import java.nio.charset.StandardCharsets import java.nio.file.{Files, Path, Paths} @@ -56,13 +56,14 @@ trait ProcBuilder { } } + protected val engineScalaBinaryVersion: String = SCALA_COMPILE_VERSION + /** * The engine jar or other runnable jar containing the main method */ def mainResource: Option[String] = { // 1. get the main resource jar for user specified config first - // TODO use SPARK_SCALA_VERSION instead of SCALA_COMPILE_VERSION - val jarName = s"${module}_$SCALA_COMPILE_VERSION-$KYUUBI_VERSION.jar" + val jarName: String = s"${module}_$engineScalaBinaryVersion-$KYUUBI_VERSION.jar" conf.getOption(s"kyuubi.session.engine.$shortName.main.resource").filter { userSpecified => // skip check exist if not local file. val uri = new URI(userSpecified) @@ -295,6 +296,11 @@ trait ProcBuilder { } } + protected lazy val engineHomeDirFilter: FileFilter = file => { + val fileName = file.getName + file.isDirectory && fileName.contains(s"$shortName-") && !fileName.contains("-engine") + } + /** * Get the home directly that contains binary distributions of engines. * @@ -311,9 +317,6 @@ trait ProcBuilder { * @return SPARK_HOME, HIVE_HOME, etc. */ protected def getEngineHome(shortName: String): String = { - val homeDirFilter: FilenameFilter = (dir: File, name: String) => - dir.isDirectory && name.contains(s"$shortName-") && !name.contains("-engine") - val homeKey = s"${shortName.toUpperCase}_HOME" // 1. get from env, e.g. SPARK_HOME, FLINK_HOME env.get(homeKey) @@ -321,14 +324,14 @@ trait ProcBuilder { // 2. get from $KYUUBI_HOME/externals/kyuubi-download/target env.get(KYUUBI_HOME).flatMap { p => val candidates = Paths.get(p, "externals", "kyuubi-download", "target") - .toFile.listFiles(homeDirFilter) + .toFile.listFiles(engineHomeDirFilter) if (candidates == null) None else candidates.map(_.toPath).headOption }.filter(Files.exists(_)).map(_.toAbsolutePath.toFile.getCanonicalPath) }.orElse { // 3. 
get from kyuubi-server/../externals/kyuubi-download/target Utils.getCodeSourceLocation(getClass).split("kyuubi-server").flatMap { cwd => val candidates = Paths.get(cwd, "externals", "kyuubi-download", "target") - .toFile.listFiles(homeDirFilter) + .toFile.listFiles(engineHomeDirFilter) if (candidates == null) None else candidates.map(_.toPath).headOption }.find(Files.exists(_)).map(_.toAbsolutePath.toFile.getCanonicalPath) } match { diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/spark/SparkProcessBuilder.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/spark/SparkProcessBuilder.scala index 351eddb7567..02f4064afc6 100644 --- a/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/spark/SparkProcessBuilder.scala +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/spark/SparkProcessBuilder.scala @@ -17,7 +17,7 @@ package org.apache.kyuubi.engine.spark -import java.io.{File, IOException} +import java.io.{File, FileFilter, IOException} import java.nio.file.Paths import java.util.Locale @@ -25,6 +25,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import com.google.common.annotations.VisibleForTesting +import org.apache.commons.lang3.StringUtils import org.apache.hadoop.security.UserGroupInformation import org.apache.kyuubi._ @@ -35,8 +36,7 @@ import org.apache.kyuubi.engine.ProcBuilder.KYUUBI_ENGINE_LOG_PATH_KEY import org.apache.kyuubi.ha.HighAvailabilityConf import org.apache.kyuubi.ha.client.AuthTypes import org.apache.kyuubi.operation.log.OperationLog -import org.apache.kyuubi.util.KubernetesUtils -import org.apache.kyuubi.util.Validator +import org.apache.kyuubi.util.{KubernetesUtils, Validator} class SparkProcessBuilder( override val proxyUser: String, @@ -102,6 +102,25 @@ class SparkProcessBuilder( } } + private[kyuubi] def extractSparkCoreScalaVersion(fileNames: Iterable[String]): String = { + fileNames.collectFirst { case SPARK_CORE_SCALA_VERSION_REGEX(scalaVersion) => scalaVersion } + .getOrElse(throw new KyuubiException("Failed to extract Scala version from spark-core jar")) + } + + override protected val engineScalaBinaryVersion: String = { + val sparkCoreScalaVersion = + extractSparkCoreScalaVersion(Paths.get(sparkHome, "jars").toFile.list()) + StringUtils.defaultIfBlank(System.getenv("SPARK_SCALA_VERSION"), sparkCoreScalaVersion) + } + + override protected lazy val engineHomeDirFilter: FileFilter = file => { + val r = SCALA_COMPILE_VERSION match { + case "2.12" => SPARK_HOME_REGEX_SCALA_212 + case "2.13" => SPARK_HOME_REGEX_SCALA_213 + } + file.isDirectory && r.findFirstMatchIn(file.getName).isDefined + } + override protected lazy val commands: Array[String] = { // complete `spark.master` if absent on kubernetes completeMasterUrl(conf) @@ -314,4 +333,13 @@ object SparkProcessBuilder { final private val SPARK_SUBMIT_FILE = if (Utils.isWindows) "spark-submit.cmd" else "spark-submit" final private val SPARK_CONF_DIR = "SPARK_CONF_DIR" final private val SPARK_CONF_FILE_NAME = "spark-defaults.conf" + + final private[kyuubi] val SPARK_CORE_SCALA_VERSION_REGEX = + """^spark-core_(\d\.\d+).*.jar$""".r + + final private[kyuubi] val SPARK_HOME_REGEX_SCALA_212 = + """^spark-\d+\.\d+\.\d+-bin-hadoop\d+(\.\d+)?$""".r + + final private[kyuubi] val SPARK_HOME_REGEX_SCALA_213 = + """^spark-\d+\.\d+\.\d+-bin-hadoop\d(\.\d+)?+-scala\d+(\.\d+)?$""".r } diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/api/v1/AdminResource.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/api/v1/AdminResource.scala 
index 5f410ab7de9..3c6f2a19782 100644 --- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/api/v1/AdminResource.scala +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/api/v1/AdminResource.scala @@ -298,15 +298,15 @@ private[v1] class AdminResource extends ApiRequestContext with Logging { } val engines = ListBuffer[Engine]() val engineSpace = fe.getConf.get(HA_NAMESPACE) - val shareLevel = fe.getConf.get(ENGINE_SHARE_LEVEL) - val engineType = fe.getConf.get(ENGINE_TYPE) + val finalShareLevel = Option(shareLevel).getOrElse(fe.getConf.get(ENGINE_SHARE_LEVEL)) + val finalEngineType = Option(engineType).getOrElse(fe.getConf.get(ENGINE_TYPE)) withDiscoveryClient(fe.getConf) { discoveryClient => - val commonParent = s"/${engineSpace}_${KYUUBI_VERSION}_${shareLevel}_$engineType" + val commonParent = s"/${engineSpace}_${KYUUBI_VERSION}_${finalShareLevel}_$finalEngineType" info(s"Listing engine nodes for $commonParent") try { discoveryClient.getChildren(commonParent).map { user => - val engine = getEngine(user, engineType, shareLevel, "", "") + val engine = getEngine(user, finalEngineType, finalShareLevel, "", "") val engineSpace = getEngineSpace(engine) discoveryClient.getChildren(engineSpace).map { child => info(s"Listing engine nodes for $engineSpace/$child") @@ -324,9 +324,12 @@ private[v1] class AdminResource extends ApiRequestContext with Logging { } } catch { case nne: NoNodeException => - error(s"No such engine for engine type: $engineType, share level: $shareLevel", nne) + error( + s"No such engine for engine type: $finalEngineType," + + s" share level: $finalShareLevel", + nne) throw new NotFoundException( - s"No such engine for engine type: $engineType, share level: $shareLevel") + s"No such engine for engine type: $finalEngineType, share level: $finalShareLevel") } } return engines.toSeq diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/session/KyuubiBatchSession.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/session/KyuubiBatchSession.scala index e10230ebfa0..8e4c5137fbf 100644 --- a/kyuubi-server/src/main/scala/org/apache/kyuubi/session/KyuubiBatchSession.scala +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/session/KyuubiBatchSession.scala @@ -61,9 +61,13 @@ class KyuubiBatchSession( override def createTime: Long = metadata.map(_.createTime).getOrElse(super.createTime) override def getNoOperationTime: Long = { - if (batchJobSubmissionOp != null && !OperationState.isTerminal( - batchJobSubmissionOp.getStatus.state)) { - 0L + if (batchJobSubmissionOp != null) { + val batchStatus = batchJobSubmissionOp.getStatus + if (!OperationState.isTerminal(batchStatus.state)) { + 0L + } else { + System.currentTimeMillis() - batchStatus.completed + } } else { super.getNoOperationTime } diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/SparkProcessBuilderSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/SparkProcessBuilderSuite.scala index a4227d26e74..408f42f6404 100644 --- a/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/SparkProcessBuilderSuite.scala +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/SparkProcessBuilderSuite.scala @@ -26,7 +26,7 @@ import java.util.concurrent.{Executors, TimeUnit} import org.scalatest.time.SpanSugar._ import org.scalatestplus.mockito.MockitoSugar -import org.apache.kyuubi.{KerberizedTestHelper, KyuubiSQLException, Utils} +import org.apache.kyuubi._ import org.apache.kyuubi.config.KyuubiConf import 
org.apache.kyuubi.config.KyuubiConf.{ENGINE_LOG_TIMEOUT, ENGINE_SPARK_MAIN_RESOURCE} import org.apache.kyuubi.engine.ProcBuilder.KYUUBI_ENGINE_LOG_PATH_KEY @@ -34,6 +34,7 @@ import org.apache.kyuubi.engine.spark.SparkProcessBuilder._ import org.apache.kyuubi.ha.HighAvailabilityConf import org.apache.kyuubi.ha.client.AuthTypes import org.apache.kyuubi.service.ServiceUtils +import org.apache.kyuubi.util.AssertionUtils._ class SparkProcessBuilderSuite extends KerberizedTestHelper with MockitoSugar { private def conf = KyuubiConf().set("kyuubi.on", "off") @@ -363,6 +364,46 @@ class SparkProcessBuilderSuite extends KerberizedTestHelper with MockitoSugar { .appendPodNameConf(conf3).get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) assert(execPodNamePrefix3 === Some(s"kyuubi-$engineRefId")) } + + test("extract spark core scala version") { + val builder = new SparkProcessBuilder("kentyao", KyuubiConf(false)) + Seq( + "spark-core_2.13-3.4.1.jar", + "spark-core_2.13-3.5.0-abc-20230921.jar", + "spark-core_2.13-3.5.0-xyz-1.2.3.jar", + "spark-core_2.13-3.5.0.1.jar", + "spark-core_2.13.2-3.5.0.jar").foreach { f => + assertResult("2.13")(builder.extractSparkCoreScalaVersion(Seq(f))) + } + + Seq( + "spark-dummy_2.13-3.5.0.jar", + "spark-core_2.13-3.5.0.1.zip", + "yummy-spark-core_2.13-3.5.0.jar").foreach { f => + assertThrows[KyuubiException](builder.extractSparkCoreScalaVersion(Seq(f))) + } + } + + test("match scala version of spark home") { + SCALA_COMPILE_VERSION match { + case "2.12" => Seq( + "spark-3.2.4-bin-hadoop3.2", + "spark-3.2.4-bin-hadoop2.7", + "spark-3.4.1-bin-hadoop3") + .foreach { sparkHome => + assertMatches(sparkHome, SPARK_HOME_REGEX_SCALA_212) + assertNotMatches(sparkHome, SPARK_HOME_REGEX_SCALA_213) + } + case "2.13" => Seq( + "spark-3.2.4-bin-hadoop3.2-scala2.13", + "spark-3.4.1-bin-hadoop3-scala2.13", + "spark-3.5.0-bin-hadoop3-scala2.13") + .foreach { sparkHome => + assertMatches(sparkHome, SPARK_HOME_REGEX_SCALA_213) + assertNotMatches(sparkHome, SPARK_HOME_REGEX_SCALA_212) + } + } + } } class FakeSparkProcessBuilder(config: KyuubiConf) diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/api/v1/AdminResourceSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/api/v1/AdminResourceSuite.scala index 6ca00c802c9..ea87e3ea0d8 100644 --- a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/api/v1/AdminResourceSuite.scala +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/api/v1/AdminResourceSuite.scala @@ -18,6 +18,7 @@ package org.apache.kyuubi.server.api.v1 import java.nio.charset.StandardCharsets +import java.time.Duration import java.util.{Base64, UUID} import javax.ws.rs.client.Entity import javax.ws.rs.core.{GenericType, MediaType} @@ -49,6 +50,7 @@ class AdminResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper { override protected lazy val conf: KyuubiConf = KyuubiConf() .set(KyuubiConf.SERVER_ADMINISTRATORS, Set("admin001")) + .set(KyuubiConf.ENGINE_IDLE_TIMEOUT, Duration.ofMinutes(3).toMillis) private val encodeAuthorization: String = { new String( @@ -275,7 +277,6 @@ class AdminResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper { conf.set(KyuubiConf.ENGINE_TYPE, SPARK_SQL.toString) conf.set(KyuubiConf.FRONTEND_THRIFT_BINARY_BIND_PORT, 0) conf.set(HighAvailabilityConf.HA_NAMESPACE, "kyuubi_test") - conf.set(KyuubiConf.ENGINE_IDLE_TIMEOUT, 180000L) conf.set(KyuubiConf.GROUP_PROVIDER, "hadoop") val engine = @@ -320,7 +321,6 @@ class AdminResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper { 
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/api/v1/AdminResourceSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/api/v1/AdminResourceSuite.scala
index 6ca00c802c9..ea87e3ea0d8 100644
--- a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/api/v1/AdminResourceSuite.scala
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/api/v1/AdminResourceSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.kyuubi.server.api.v1
 
 import java.nio.charset.StandardCharsets
+import java.time.Duration
 import java.util.{Base64, UUID}
 import javax.ws.rs.client.Entity
 import javax.ws.rs.core.{GenericType, MediaType}
@@ -49,6 +50,7 @@ class AdminResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper {
 
   override protected lazy val conf: KyuubiConf = KyuubiConf()
     .set(KyuubiConf.SERVER_ADMINISTRATORS, Set("admin001"))
+    .set(KyuubiConf.ENGINE_IDLE_TIMEOUT, Duration.ofMinutes(3).toMillis)
 
   private val encodeAuthorization: String = {
     new String(
@@ -275,7 +277,6 @@ class AdminResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper {
     conf.set(KyuubiConf.ENGINE_TYPE, SPARK_SQL.toString)
     conf.set(KyuubiConf.FRONTEND_THRIFT_BINARY_BIND_PORT, 0)
     conf.set(HighAvailabilityConf.HA_NAMESPACE, "kyuubi_test")
-    conf.set(KyuubiConf.ENGINE_IDLE_TIMEOUT, 180000L)
     conf.set(KyuubiConf.GROUP_PROVIDER, "hadoop")
 
     val engine =
@@ -320,7 +321,6 @@ class AdminResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper {
     conf.set(KyuubiConf.ENGINE_TYPE, SPARK_SQL.toString)
     conf.set(KyuubiConf.FRONTEND_THRIFT_BINARY_BIND_PORT, 0)
     conf.set(HighAvailabilityConf.HA_NAMESPACE, "kyuubi_test")
-    conf.set(KyuubiConf.ENGINE_IDLE_TIMEOUT, 180000L)
     conf.set(KyuubiConf.GROUP_PROVIDER, "hadoop")
 
     val engine =
@@ -366,7 +366,6 @@ class AdminResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper {
     conf.set(KyuubiConf.ENGINE_TYPE, SPARK_SQL.toString)
     conf.set(KyuubiConf.FRONTEND_THRIFT_BINARY_BIND_PORT, 0)
     conf.set(HighAvailabilityConf.HA_NAMESPACE, "kyuubi_test")
-    conf.set(KyuubiConf.ENGINE_IDLE_TIMEOUT, 180000L)
     conf.set(KyuubiConf.GROUP_PROVIDER, "hadoop")
 
     val id = UUID.randomUUID().toString
@@ -401,7 +400,6 @@ class AdminResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper {
     conf.set(KyuubiConf.ENGINE_TYPE, SPARK_SQL.toString)
     conf.set(KyuubiConf.FRONTEND_THRIFT_BINARY_BIND_PORT, 0)
     conf.set(HighAvailabilityConf.HA_NAMESPACE, "kyuubi_test")
-    conf.set(KyuubiConf.ENGINE_IDLE_TIMEOUT, 180000L)
     conf.set(KyuubiConf.GROUP_PROVIDER, "hadoop")
 
     val engine =
@@ -446,7 +444,6 @@ class AdminResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper {
     conf.set(KyuubiConf.ENGINE_TYPE, SPARK_SQL.toString)
     conf.set(KyuubiConf.FRONTEND_THRIFT_BINARY_BIND_PORT, 0)
     conf.set(HighAvailabilityConf.HA_NAMESPACE, "kyuubi_test")
-    conf.set(KyuubiConf.ENGINE_IDLE_TIMEOUT, 180000L)
     conf.set(KyuubiConf.GROUP_PROVIDER, "hadoop")
 
     val engine =
@@ -492,7 +489,6 @@ class AdminResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper {
     conf.set(KyuubiConf.ENGINE_TYPE, SPARK_SQL.toString)
     conf.set(KyuubiConf.FRONTEND_THRIFT_BINARY_BIND_PORT, 0)
     conf.set(HighAvailabilityConf.HA_NAMESPACE, "kyuubi_test")
-    conf.set(KyuubiConf.ENGINE_IDLE_TIMEOUT, 180000L)
     conf.set(KyuubiConf.GROUP_PROVIDER, "hadoop")
 
     val engineSpace = DiscoveryPaths.makePath(
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/session/SessionSigningSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/session/SessionSigningSuite.scala
index a53e4650727..11121a74fb0 100644
--- a/kyuubi-server/src/test/scala/org/apache/kyuubi/session/SessionSigningSuite.scala
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/session/SessionSigningSuite.scala
@@ -86,8 +86,9 @@ class SessionSigningSuite extends WithKyuubiServer with HiveJDBCTestHelper {
     assert(rs2.next())
 
     // skipping prefix "res0: String = " of returned scala result
-    val publicKeyStr = rs1.getString(1).substring(15)
-    val sessionUserSign = rs2.getString(1).substring(15)
+    val sep = " = "
+    val publicKeyStr = StringUtils.substringAfter(rs1.getString(1), sep)
+    val sessionUserSign = StringUtils.substringAfter(rs2.getString(1), sep)
 
     assert(StringUtils.isNotBlank(publicKeyStr))
     assert(StringUtils.isNotBlank(sessionUserSign))
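On the `SessionSigningSuite` change just above: a fixed `substring(15)` only strips the REPL prefix correctly when it is exactly `res0: String = `, whereas splitting on the `" = "` separator with commons-lang3 `StringUtils.substringAfter` tolerates any `resN` index. A small self-contained illustration (the sample strings are invented and merely mimic Scala REPL output):

```scala
import org.apache.commons.lang3.StringUtils

// Shows why separator-based extraction is more robust than a fixed offset.
object SubstringAfterSketch {
  def main(args: Array[String]): Unit = {
    val short = "res0: String = abc123"
    val long = "res10: String = abc123"

    // Fixed-offset slicing assumes the prefix is exactly 15 characters long.
    println(short.substring(15)) // abc123
    println(long.substring(15))  // " abc123" -- off by one, the prefix is 16 chars here

    // Splitting on the " = " separator works regardless of the res index.
    println(StringUtils.substringAfter(short, " = ")) // abc123
    println(StringUtils.substringAfter(long, " = "))  // abc123
  }
}
```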
diff --git a/kyuubi-util-scala/src/test/scala/org/apache/kyuubi/util/AssertionUtils.scala b/kyuubi-util-scala/src/test/scala/org/apache/kyuubi/util/AssertionUtils.scala
index 2e9a0d3fe2d..9d33993b9d2 100644
--- a/kyuubi-util-scala/src/test/scala/org/apache/kyuubi/util/AssertionUtils.scala
+++ b/kyuubi-util-scala/src/test/scala/org/apache/kyuubi/util/AssertionUtils.scala
@@ -23,14 +23,16 @@ import java.util.Locale
 import scala.collection.Traversable
 import scala.io.Source
 import scala.reflect.ClassTag
+import scala.util.matching.Regex
 
-import org.scalactic.{source, Prettifier}
+import org.scalactic.Prettifier
+import org.scalactic.source.Position
 import org.scalatest.Assertions._
 
 object AssertionUtils {
 
   def assertEqualsIgnoreCase(expected: AnyRef)(actual: AnyRef)(
-      implicit pos: source.Position): Unit = {
+      implicit pos: Position): Unit = {
     val isEqualsIgnoreCase = (Option(expected), Option(actual)) match {
       case (Some(expectedStr: String), Some(actualStr: String)) =>
         expectedStr.equalsIgnoreCase(actualStr)
@@ -44,15 +46,15 @@ object AssertionUtils {
     }
   }
 
-  def assertStartsWithIgnoreCase(expectedPrefix: String)(actual: String)(
-      implicit pos: source.Position): Unit = {
+  def assertStartsWithIgnoreCase(expectedPrefix: String)(actual: String)(implicit
+      pos: Position): Unit = {
     if (!actual.toLowerCase(Locale.ROOT).startsWith(expectedPrefix.toLowerCase(Locale.ROOT))) {
       fail(s"Expected starting with '$expectedPrefix' ignoring case, but got [$actual]")(pos)
     }
   }
 
-  def assertExistsIgnoreCase(expected: String)(actual: Iterable[String])(
-      implicit pos: source.Position): Unit = {
+  def assertExistsIgnoreCase(expected: String)(actual: Iterable[String])(implicit
+      pos: Position): Unit = {
     if (!actual.exists(_.equalsIgnoreCase(expected))) {
       fail(s"Expected containing '$expected' ignoring case, but got [$actual]")(pos)
     }
@@ -73,7 +75,7 @@ object AssertionUtils {
       regenScript: String,
       splitFirstExpectedLine: Boolean = false)(implicit
       prettifier: Prettifier,
-      pos: source.Position): Unit = {
+      pos: Position): Unit = {
     val fileSource = Source.fromFile(path.toUri, StandardCharsets.UTF_8.name())
     try {
       def expectedLinesIter = if (splitFirstExpectedLine) {
@@ -104,13 +106,44 @@ object AssertionUtils {
     }
   }
 
+  /**
+   * Assert the iterable contains all the expected elements
+   */
+  def assertContains(actual: TraversableOnce[AnyRef], expected: AnyRef*)(implicit
+      prettifier: Prettifier,
+      pos: Position): Unit =
+    withClue(s", expected containing [${expected.mkString(", ")}]") {
+      val actualSeq = actual.toSeq
+      expected.foreach { elem => assert(actualSeq.contains(elem))(prettifier, pos) }
+    }
+
+  /**
+   * Asserts the string matches the regex
+   */
+  def assertMatches(actual: String, regex: Regex)(implicit
+      prettifier: Prettifier,
+      pos: Position): Unit =
+    withClue(s"'$actual' expected matching the regex '$regex'") {
+      assert(regex.findFirstMatchIn(actual).isDefined)(prettifier, pos)
+    }
+
+  /**
+   * Asserts the string does not match the regex
+   */
+  def assertNotMatches(actual: String, regex: Regex)(implicit
+      prettifier: Prettifier,
+      pos: Position): Unit =
+    withClue(s"'$actual' expected not matching the regex '$regex'") {
+      assert(regex.findFirstMatchIn(actual).isEmpty)(prettifier, pos)
+    }
+
   /**
    * Asserts that the given function throws an exception of the given type
    * and with the exception message equals to expected string
    */
   def interceptEquals[T <: Exception](f: => Any)(expected: String)(implicit
       classTag: ClassTag[T],
-      pos: source.Position): Unit = {
+      pos: Position): Unit = {
     assert(expected != null)
     val exception = intercept[T](f)(classTag, pos)
     assertResult(expected)(exception.getMessage)
@@ -122,7 +155,7 @@ object AssertionUtils {
    */
   def interceptContains[T <: Exception](f: => Any)(contained: String)(implicit
       classTag: ClassTag[T],
-      pos: source.Position): Unit = {
+      pos: Position): Unit = {
     assert(contained != null)
     val exception = intercept[T](f)(classTag, pos)
     assert(exception.getMessage.contains(contained))
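Since the new `assertContains`, `assertMatches`, and `assertNotMatches` helpers above are only exercised indirectly elsewhere in this change, here is a minimal usage sketch; the suite name and inputs are invented, and plain `AnyFunSuite` stands in for whatever base class a real Kyuubi suite would extend:

```scala
import org.scalatest.funsuite.AnyFunSuite

import org.apache.kyuubi.util.AssertionUtils._

// Invented suite for illustration; it only shows how the new helpers read at call sites.
class AssertionUtilsUsageSketch extends AnyFunSuite {

  test("regex helpers") {
    assertMatches("spark-3.5.0-bin-hadoop3-scala2.13", """scala2\.13$""".r)
    assertNotMatches("spark-3.4.1-bin-hadoop3", """scala2\.13$""".r)
  }

  test("containment helper") {
    assertContains(Seq("2.12", "2.13"), "2.13")
  }
}
```

On failure, the `withClue` wrapper in these helpers prepends the offending value and pattern to the assertion message, which is the main ergonomic gain over inlining `regex.findFirstMatchIn(...).isDefined` at every call site.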
diff --git a/pom.xml b/pom.xml
index 7e904e97b5d..df1e0c3b706 100644
--- a/pom.xml
+++ b/pom.xml
@@ -138,7 +138,7 @@
     0.9.3
     0.62.2
     1.17.1
-    flink-${flink.version}-bin-scala_${scala.binary.version}.tgz
+    flink-${flink.version}-bin-scala_2.12.tgz
     ${apache.archive.dist}/flink/flink-${flink.version}
     false
     3.0.2
@@ -158,7 +158,7 @@
     false
     4.5.14
     4.4.16
-    1.3.1
+    1.4.0
     2.15.0
     4.0.4
     2.3.2
@@ -198,7 +198,8 @@
     -->
     3.4.1
     3.4
-    spark-${spark.version}-bin-hadoop3.tgz
+
+    spark-${spark.version}-bin-hadoop3${spark.archive.scala.suffix}.tgz
     ${apache.archive.dist}/spark/spark-${spark.version}
     false
     3.42.0.0
@@ -2135,6 +2136,7 @@
         2.13
         2.13.8
+        -scala${scala.binary.version}
@@ -2190,6 +2192,8 @@
         3.1.3
         3.1
         1.0.1
+
+        1.3.1
         spark-${spark.version}-bin-hadoop3.2.tgz
         org.scalatest.tags.Slow
@@ -2205,7 +2209,7 @@
         3.2.4
         3.2
         2.0.2
-        spark-${spark.version}-bin-hadoop3.2.tgz
+        spark-${spark.version}-bin-hadoop3.2${spark.archive.scala.suffix}.tgz
         org.scalatest.tags.Slow
@@ -2239,6 +2243,19 @@
+
+      spark-3.5
+
+        extensions/spark/kyuubi-extension-spark-3-5
+
+
+        2.4.0
+        3.5.0
+        3.5
+        org.scalatest.tags.Slow,org.apache.kyuubi.tags.DeltaTest,org.apache.kyuubi.tags.IcebergTest,org.apache.kyuubi.tags.PySparkTest
+
+
+
     spark-master