diff --git a/LICENSE b/LICENSE index c2b0d72663b55..b771bd552b762 100644 --- a/LICENSE +++ b/LICENSE @@ -201,102 +201,61 @@ limitations under the License. -======================================================================= -Apache Spark Subcomponents: - -The Apache Spark project contains subcomponents with separate copyright -notices and license terms. Your use of the source code for the these -subcomponents is subject to the terms and conditions of the following -licenses. - - -======================================================================== -For heapq (pyspark/heapq3.py): -======================================================================== - -See license/LICENSE-heapq.txt - -======================================================================== -For SnapTree: -======================================================================== - -See license/LICENSE-SnapTree.txt - -======================================================================== -For jbcrypt: -======================================================================== - -See license/LICENSE-jbcrypt.txt - -======================================================================== -BSD-style licenses -======================================================================== - -The following components are provided under a BSD-style license. See project link for details. -The text of each license is also included at licenses/LICENSE-[project].txt. - - (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) - (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) - (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) - (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://wwww.antlr.org/) - (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) - (BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org) - (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) - (BSD) JLine (jline:jline:0.9.94 - http://jline.sourceforge.net) - (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.3 - http://paranamer.codehaus.org/paranamer) - (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.6 - http://paranamer.codehaus.org/paranamer) - (BSD 3 Clause) Scala (http://www.scala-lang.org/download/#License) - (Interpreter classes (all .scala files in repl/src/main/scala - except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), - and for SerializableMapWrapper in JavaUtils.scala) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.11.8 - http://www.scala-lang.org/) - (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) - (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org) - (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) - (New BSD License) Kryo (com.esotericsoftware:kryo:3.0.3 - https://github.com/EsotericSoftware/kryo) - (New BSD License) MinLog (com.esotericsoftware:minlog:1.3.0 - https://github.com/EsotericSoftware/minlog) - (New BSD license) Protocol Buffer Java API (com.google.protobuf:protobuf-java:2.5.0 - http://code.google.com/p/protobuf) - (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) - (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) - (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.6 - http://py4j.sourceforge.net/) - (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) - (BSD licence) sbt and sbt-launch-lib.bash - (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) - (BSD 3 Clause) DPark (https://github.com/douban/dpark/blob/master/LICENSE) - (BSD 3 Clause) CloudPickle (https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE) - (BSD 2 Clause) Zstd-jni (https://github.com/luben/zstd-jni/blob/master/LICENSE) - (BSD license) Zstd (https://github.com/facebook/zstd/blob/v1.3.1/LICENSE) - -======================================================================== -MIT licenses -======================================================================== - -The following components are provided under the MIT License. See project link for details. -The text of each license is also included at licenses/LICENSE-[project].txt. - - (MIT License) JCL 1.1.1 implemented over SLF4J (org.slf4j:jcl-over-slf4j:1.7.5 - http://www.slf4j.org) - (MIT License) JUL to SLF4J bridge (org.slf4j:jul-to-slf4j:1.7.5 - http://www.slf4j.org) - (MIT License) SLF4J API Module (org.slf4j:slf4j-api:1.7.5 - http://www.slf4j.org) - (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) - (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) - (MIT License) scopt (com.github.scopt:scopt_2.11:3.2.0 - https://github.com/scopt/scopt) - (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) - (MIT License) jquery (https://jquery.org/license/) - (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) - (MIT License) graphlib-dot (https://github.com/cpettitt/graphlib-dot) - (MIT License) dagre-d3 (https://github.com/cpettitt/dagre-d3) - (MIT License) sorttable (https://github.com/stuartlangridge/sorttable) - (MIT License) boto (https://github.com/boto/boto/blob/develop/LICENSE) - (MIT License) datatables (http://datatables.net/license) - (MIT License) mustache (https://github.com/mustache/mustache/blob/master/LICENSE) - (MIT License) cookies (http://code.google.com/p/cookies/wiki/License) - (MIT License) blockUI (http://jquery.malsup.com/block/) - (MIT License) RowsGroup (http://datatables.net/license/mit) - (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) - (MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE) - (MIT License) machinist (https://github.com/typelevel/machinist) +------------------------------------------------------------------------------------ +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses/ +for text of these licenses. + + +Apache Software Foundation License 2.0 +-------------------------------------- + +common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java +core/src/main/java/org/apache/spark/util/collection/TimSort.java +core/src/main/resources/org/apache/spark/ui/static/bootstrap* +core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* +core/src/main/resources/org/apache/spark/ui/static/vis* +docs/js/vendor/bootstrap.js + + +Python Software Foundation License +---------------------------------- + +pyspark/heapq3.py + + +BSD 3-Clause +------------ + +python/lib/py4j-*-src.zip +python/pyspark/cloudpickle.py +python/pyspark/join.py +core/src/main/resources/org/apache/spark/ui/static/d3.min.js + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. + + +MIT License +----------- + +core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +core/src/main/resources/org/apache/spark/ui/static/*dataTables* +core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js +ore/src/main/resources/org/apache/spark/ui/static/jquery* +core/src/main/resources/org/apache/spark/ui/static/sorttable.js +docs/js/vendor/anchor.min.js +docs/js/vendor/jquery* +docs/js/vendor/modernizer* + + +Creative Commons CC0 1.0 Universal Public Domain Dedication +----------------------------------------------------------- +(see LICENSE-CC0.txt) + +data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg +data/mllib/images/kittens/54893.jpg +data/mllib/images/kittens/DP153539.jpg +data/mllib/images/kittens/DP802813.jpg +data/mllib/images/multi-channel/chr30.4.184.jpg \ No newline at end of file diff --git a/LICENSE-binary b/LICENSE-binary new file mode 100644 index 0000000000000..c033dd8ad2e6a --- /dev/null +++ b/LICENSE-binary @@ -0,0 +1,520 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------ +This project bundles some components that are also licensed under the Apache +License Version 2.0: + +commons-beanutils:commons-beanutils +org.apache.zookeeper:zookeeper +oro:oro +commons-configuration:commons-configuration +commons-digester:commons-digester +com.chuusai:shapeless_2.11 +com.googlecode.javaewah:JavaEWAH +com.twitter:chill-java +com.twitter:chill_2.11 +com.univocity:univocity-parsers +javax.jdo:jdo-api +joda-time:joda-time +net.sf.opencsv:opencsv +org.apache.derby:derby +org.objenesis:objenesis +org.roaringbitmap:RoaringBitmap +org.scalanlp:breeze-macros_2.11 +org.scalanlp:breeze_2.11 +org.typelevel:macro-compat_2.11 +org.yaml:snakeyaml +org.apache.xbean:xbean-asm5-shaded +com.squareup.okhttp3:logging-interceptor +com.squareup.okhttp3:okhttp +com.squareup.okio:okio +net.java.dev.jets3t:jets3t +org.apache.spark:spark-catalyst_2.11 +org.apache.spark:spark-kvstore_2.11 +org.apache.spark:spark-launcher_2.11 +org.apache.spark:spark-mllib-local_2.11 +org.apache.spark:spark-network-common_2.11 +org.apache.spark:spark-network-shuffle_2.11 +org.apache.spark:spark-sketch_2.11 +org.apache.spark:spark-tags_2.11 +org.apache.spark:spark-unsafe_2.11 +commons-httpclient:commons-httpclient +com.vlkan:flatbuffers +com.ning:compress-lzf +io.airlift:aircompressor +io.dropwizard.metrics:metrics-core +io.dropwizard.metrics:metrics-ganglia +io.dropwizard.metrics:metrics-graphite +io.dropwizard.metrics:metrics-json +io.dropwizard.metrics:metrics-jvm +org.iq80.snappy:snappy +com.clearspring.analytics:stream +com.jamesmurty.utils:java-xmlbuilder +commons-codec:commons-codec +commons-collections:commons-collections +io.fabric8:kubernetes-client +io.fabric8:kubernetes-model +io.netty:netty +io.netty:netty-all +net.hydromatic:eigenbase-properties +net.sf.supercsv:super-csv +org.apache.arrow:arrow-format +org.apache.arrow:arrow-memory +org.apache.arrow:arrow-vector +org.apache.calcite:calcite-avatica +org.apache.calcite:calcite-core +org.apache.calcite:calcite-linq4j +org.apache.commons:commons-crypto +org.apache.commons:commons-lang3 +org.apache.hadoop:hadoop-annotations +org.apache.hadoop:hadoop-auth +org.apache.hadoop:hadoop-client +org.apache.hadoop:hadoop-common +org.apache.hadoop:hadoop-hdfs +org.apache.hadoop:hadoop-mapreduce-client-app +org.apache.hadoop:hadoop-mapreduce-client-common +org.apache.hadoop:hadoop-mapreduce-client-core +org.apache.hadoop:hadoop-mapreduce-client-jobclient +org.apache.hadoop:hadoop-mapreduce-client-shuffle +org.apache.hadoop:hadoop-yarn-api +org.apache.hadoop:hadoop-yarn-client +org.apache.hadoop:hadoop-yarn-common +org.apache.hadoop:hadoop-yarn-server-common +org.apache.hadoop:hadoop-yarn-server-web-proxy +org.apache.httpcomponents:httpclient +org.apache.httpcomponents:httpcore +org.apache.orc:orc-core +org.apache.orc:orc-mapreduce +org.mortbay.jetty:jetty +org.mortbay.jetty:jetty-util +com.jolbox:bonecp +org.json4s:json4s-ast_2.11 +org.json4s:json4s-core_2.11 +org.json4s:json4s-jackson_2.11 +org.json4s:json4s-scalap_2.11 +com.carrotsearch:hppc +com.fasterxml.jackson.core:jackson-annotations +com.fasterxml.jackson.core:jackson-core +com.fasterxml.jackson.core:jackson-databind +com.fasterxml.jackson.dataformat:jackson-dataformat-yaml +com.fasterxml.jackson.module:jackson-module-jaxb-annotations +com.fasterxml.jackson.module:jackson-module-paranamer +com.fasterxml.jackson.module:jackson-module-scala_2.11 +com.github.mifmif:generex +com.google.code.findbugs:jsr305 +com.google.code.gson:gson +com.google.inject:guice +com.google.inject.extensions:guice-servlet +com.twitter:parquet-hadoop-bundle +commons-beanutils:commons-beanutils-core +commons-cli:commons-cli +commons-dbcp:commons-dbcp +commons-io:commons-io +commons-lang:commons-lang +commons-logging:commons-logging +commons-net:commons-net +commons-pool:commons-pool +io.fabric8:zjsonpatch +javax.inject:javax.inject +javax.validation:validation-api +log4j:apache-log4j-extras +log4j:log4j +net.sf.jpam:jpam +org.apache.avro:avro +org.apache.avro:avro-ipc +org.apache.avro:avro-mapred +org.apache.commons:commons-compress +org.apache.commons:commons-math3 +org.apache.curator:curator-client +org.apache.curator:curator-framework +org.apache.curator:curator-recipes +org.apache.directory.api:api-asn1-api +org.apache.directory.api:api-util +org.apache.directory.server:apacheds-i18n +org.apache.directory.server:apacheds-kerberos-codec +org.apache.htrace:htrace-core +org.apache.ivy:ivy +org.apache.mesos:mesos +org.apache.parquet:parquet-column +org.apache.parquet:parquet-common +org.apache.parquet:parquet-encoding +org.apache.parquet:parquet-format +org.apache.parquet:parquet-hadoop +org.apache.parquet:parquet-jackson +org.apache.thrift:libfb303 +org.apache.thrift:libthrift +org.codehaus.jackson:jackson-core-asl +org.codehaus.jackson:jackson-mapper-asl +org.datanucleus:datanucleus-api-jdo +org.datanucleus:datanucleus-core +org.datanucleus:datanucleus-rdbms +org.lz4:lz4-java +org.spark-project.hive:hive-beeline +org.spark-project.hive:hive-cli +org.spark-project.hive:hive-exec +org.spark-project.hive:hive-jdbc +org.spark-project.hive:hive-metastore +org.xerial.snappy:snappy-java +stax:stax-api +xerces:xercesImpl +org.codehaus.jackson:jackson-jaxrs +org.codehaus.jackson:jackson-xc +org.eclipse.jetty:jetty-client +org.eclipse.jetty:jetty-continuation +org.eclipse.jetty:jetty-http +org.eclipse.jetty:jetty-io +org.eclipse.jetty:jetty-jndi +org.eclipse.jetty:jetty-plus +org.eclipse.jetty:jetty-proxy +org.eclipse.jetty:jetty-security +org.eclipse.jetty:jetty-server +org.eclipse.jetty:jetty-servlet +org.eclipse.jetty:jetty-servlets +org.eclipse.jetty:jetty-util +org.eclipse.jetty:jetty-webapp +org.eclipse.jetty:jetty-xml + +core/src/main/java/org/apache/spark/util/collection/TimSort.java +core/src/main/resources/org/apache/spark/ui/static/bootstrap* +core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* +core/src/main/resources/org/apache/spark/ui/static/vis* +docs/js/vendor/bootstrap.js + + +------------------------------------------------------------------------------------ +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses-binary/ +for text of these licenses. + + +BSD 2-Clause +------------ + +com.github.luben:zstd-jni +javolution:javolution +com.esotericsoftware:kryo-shaded +com.esotericsoftware:minlog +com.esotericsoftware:reflectasm +com.google.protobuf:protobuf-java +org.codehaus.janino:commons-compiler +org.codehaus.janino:janino +jline:jline +org.jodd:jodd-core + + +BSD 3-Clause +------------ + +dk.brics.automaton:automaton +org.antlr:antlr-runtime +org.antlr:ST4 +org.antlr:stringtemplate +org.antlr:antlr4-runtime +antlr:antlr +com.github.fommil.netlib:core +com.thoughtworks.paranamer:paranamer +org.scala-lang:scala-compiler +org.scala-lang:scala-library +org.scala-lang:scala-reflect +org.scala-lang.modules:scala-parser-combinators_2.11 +org.scala-lang.modules:scala-xml_2.11 +org.fusesource.leveldbjni:leveldbjni-all +net.sourceforge.f2j:arpack_combined_all +xmlenc:xmlenc +net.sf.py4j:py4j +org.jpmml:pmml-model +org.jpmml:pmml-schema + +python/lib/py4j-*-src.zip +python/pyspark/cloudpickle.py +python/pyspark/join.py +core/src/main/resources/org/apache/spark/ui/static/d3.min.js + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. + + +MIT License +----------- + +org.spire-math:spire-macros_2.11 +org.spire-math:spire_2.11 +org.typelevel:machinist_2.11 +net.razorvine:pyrolite +org.slf4j:jcl-over-slf4j +org.slf4j:jul-to-slf4j +org.slf4j:slf4j-api +org.slf4j:slf4j-log4j12 +com.github.scopt:scopt_2.11 +org.bouncycastle:bcprov-jdk15on + +core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +core/src/main/resources/org/apache/spark/ui/static/*dataTables* +core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js +ore/src/main/resources/org/apache/spark/ui/static/jquery* +core/src/main/resources/org/apache/spark/ui/static/sorttable.js +docs/js/vendor/anchor.min.js +docs/js/vendor/jquery* +docs/js/vendor/modernizer* + + +Common Development and Distribution License (CDDL) 1.0 +------------------------------------------------------ + +javax.activation:activation http://www.oracle.com/technetwork/java/javase/tech/index-jsp-138795.html +javax.xml.stream:stax-api https://jcp.org/en/jsr/detail?id=173 + + +Common Development and Distribution License (CDDL) 1.1 +------------------------------------------------------ + +javax.annotation:javax.annotation-api https://jcp.org/en/jsr/detail?id=250 +javax.servlet:javax.servlet-api https://javaee.github.io/servlet-spec/ +javax.transaction:jta http://www.oracle.com/technetwork/java/index.html +javax.ws.rs:javax.ws.rs-api https://github.com/jax-rs +javax.xml.bind:jaxb-api https://github.com/javaee/jaxb-v2 +org.glassfish.hk2:hk2-api https://github.com/javaee/glassfish +org.glassfish.hk2:hk2-locator (same) +org.glassfish.hk2:hk2-utils +org.glassfish.hk2:osgi-resource-locator +org.glassfish.hk2.external:aopalliance-repackaged +org.glassfish.hk2.external:javax.inject +org.glassfish.jersey.bundles.repackaged:jersey-guava +org.glassfish.jersey.containers:jersey-container-servlet +org.glassfish.jersey.containers:jersey-container-servlet-core +org.glassfish.jersey.core:jersey-client +org.glassfish.jersey.core:jersey-common +org.glassfish.jersey.core:jersey-server +org.glassfish.jersey.media:jersey-media-jaxb + + +Mozilla Public License (MPL) 1.1 +-------------------------------- + +com.github.rwl:jtransforms https://sourceforge.net/projects/jtransforms/ + + +Python Software Foundation License +---------------------------------- + +pyspark/heapq3.py + + +Public Domain +------------- + +aopalliance:aopalliance +net.iharder:base64 +org.tukaani:xz + + +Creative Commons CC0 1.0 Universal Public Domain Dedication +----------------------------------------------------------- +(see LICENSE-CC0.txt) + +data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg +data/mllib/images/kittens/54893.jpg +data/mllib/images/kittens/DP153539.jpg +data/mllib/images/kittens/DP802813.jpg +data/mllib/images/multi-channel/chr30.4.184.jpg diff --git a/NOTICE b/NOTICE index 6ec240efbf12e..9246cc54caa3a 100644 --- a/NOTICE +++ b/NOTICE @@ -4,664 +4,3 @@ Copyright 2014 and onwards The Apache Software Foundation. This product includes software developed at The Apache Software Foundation (http://www.apache.org/). - -======================================================================== -Common Development and Distribution License 1.0 -======================================================================== - -The following components are provided under the Common Development and Distribution License 1.0. See project link for details. - - (CDDL 1.0) Glassfish Jasper (org.mortbay.jetty:jsp-2.1:6.1.14 - http://jetty.mortbay.org/project/modules/jsp-2.1) - (CDDL 1.0) JAX-RS (https://jax-rs-spec.java.net/) - (CDDL 1.0) Servlet Specification 2.5 API (org.mortbay.jetty:servlet-api-2.5:6.1.14 - http://jetty.mortbay.org/project/modules/servlet-api-2.5) - (CDDL 1.0) (GPL2 w/ CPE) javax.annotation API (https://glassfish.java.net/nonav/public/CDDL+GPL.html) - (COMMON DEVELOPMENT AND DISTRIBUTION LICENSE (CDDL) Version 1.0) (GNU General Public Library) Streaming API for XML (javax.xml.stream:stax-api:1.0-2 - no url defined) - (Common Development and Distribution License (CDDL) v1.0) JavaBeans Activation Framework (JAF) (javax.activation:activation:1.1 - http://java.sun.com/products/javabeans/jaf/index.jsp) - -======================================================================== -Common Development and Distribution License 1.1 -======================================================================== - -The following components are provided under the Common Development and Distribution License 1.1. See project link for details. - - (CDDL 1.1) (GPL2 w/ CPE) org.glassfish.hk2 (https://hk2.java.net) - (CDDL 1.1) (GPL2 w/ CPE) JAXB API bundle for GlassFish V3 (javax.xml.bind:jaxb-api:2.2.2 - https://jaxb.dev.java.net/) - (CDDL 1.1) (GPL2 w/ CPE) JAXB RI (com.sun.xml.bind:jaxb-impl:2.2.3-1 - http://jaxb.java.net/) - (CDDL 1.1) (GPL2 w/ CPE) Jersey 2 (https://jersey.java.net) - -======================================================================== -Common Public License 1.0 -======================================================================== - -The following components are provided under the Common Public 1.0 License. See project link for details. - - (Common Public License Version 1.0) JUnit (junit:junit-dep:4.10 - http://junit.org) - (Common Public License Version 1.0) JUnit (junit:junit:3.8.1 - http://junit.org) - (Common Public License Version 1.0) JUnit (junit:junit:4.8.2 - http://junit.org) - -======================================================================== -Eclipse Public License 1.0 -======================================================================== - -The following components are provided under the Eclipse Public License 1.0. See project link for details. - - (Eclipse Public License v1.0) Eclipse JDT Core (org.eclipse.jdt:core:3.1.1 - http://www.eclipse.org/jdt/) - -======================================================================== -Mozilla Public License 1.0 -======================================================================== - -The following components are provided under the Mozilla Public License 1.0. See project link for details. - - (GPL) (LGPL) (MPL) JTransforms (com.github.rwl:jtransforms:2.4.0 - http://sourceforge.net/projects/jtransforms/) - (Mozilla Public License Version 1.1) jamon-runtime (org.jamon:jamon-runtime:2.3.1 - http://www.jamon.org/jamon-runtime/) - - - -======================================================================== -NOTICE files -======================================================================== - -The following NOTICEs are pertain to software distributed with this project. - - -// ------------------------------------------------------------------ -// NOTICE file corresponding to the section 4d of The Apache License, -// Version 2.0, in this case for -// ------------------------------------------------------------------ - -Apache Avro -Copyright 2009-2013 The Apache Software Foundation - -This product includes software developed at -The Apache Software Foundation (http://www.apache.org/). - -Apache Commons Codec -Copyright 2002-2009 The Apache Software Foundation - -This product includes software developed by -The Apache Software Foundation (http://www.apache.org/). - --------------------------------------------------------------------------------- -src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.java contains -test data from http://aspell.sourceforge.net/test/batch0.tab. - -Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org). Verbatim copying -and distribution of this entire article is permitted in any medium, -provided this notice is preserved. --------------------------------------------------------------------------------- - -Apache HttpComponents HttpClient -Copyright 1999-2011 The Apache Software Foundation - -This project contains annotations derived from JCIP-ANNOTATIONS -Copyright (c) 2005 Brian Goetz and Tim Peierls. See http://www.jcip.net - -Apache HttpComponents HttpCore -Copyright 2005-2011 The Apache Software Foundation - -Curator Recipes -Copyright 2011-2014 The Apache Software Foundation - -Curator Framework -Copyright 2011-2014 The Apache Software Foundation - -Curator Client -Copyright 2011-2014 The Apache Software Foundation - -Apache Geronimo -Copyright 2003-2008 The Apache Software Foundation - -Activation 1.1 -Copyright 2003-2007 The Apache Software Foundation - -Apache Commons Lang -Copyright 2001-2014 The Apache Software Foundation - -This product includes software from the Spring Framework, -under the Apache License 2.0 (see: StringUtils.containsWhitespace()) - -Apache log4j -Copyright 2007 The Apache Software Foundation - -# Compress LZF - -This library contains efficient implementation of LZF compression format, -as well as additional helper classes that build on JDK-provided gzip (deflat) -codec. - -## Licensing - -Library is licensed under Apache License 2.0, as per accompanying LICENSE file. - -## Credit - -Library has been written by Tatu Saloranta (tatu.saloranta@iki.fi). -It was started at Ning, inc., as an official Open Source process used by -platform backend, but after initial versions has been developed outside of -Ning by supporting community. - -Other contributors include: - -* Jon Hartlaub (first versions of streaming reader/writer; unit tests) -* Cedrik Lime: parallel LZF implementation - -Various community members have contributed bug reports, and suggested minor -fixes; these can be found from file "VERSION.txt" in SCM. - -Objenesis -Copyright 2006-2009 Joe Walnes, Henri Tremblay, Leonardo Mesquita - -Apache Commons Net -Copyright 2001-2010 The Apache Software Foundation - - The Netty Project - ================= - -Please visit the Netty web site for more information: - - * http://netty.io/ - -Copyright 2011 The Netty Project - -The Netty Project licenses this file to you under the Apache License, -version 2.0 (the "License"); you may not use this file except in compliance -with the License. You may obtain a copy of the License at: - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -License for the specific language governing permissions and limitations -under the License. - -Also, please refer to each LICENSE..txt file, which is located in -the 'license' directory of the distribution file, for the license terms of the -components that this product depends on. - -------------------------------------------------------------------------------- -This product contains the extensions to Java Collections Framework which has -been derived from the works by JSR-166 EG, Doug Lea, and Jason T. Greene: - - * LICENSE: - * license/LICENSE.jsr166y.txt (Public Domain) - * HOMEPAGE: - * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ - * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ - -This product contains a modified version of Robert Harder's Public Domain -Base64 Encoder and Decoder, which can be obtained at: - - * LICENSE: - * license/LICENSE.base64.txt (Public Domain) - * HOMEPAGE: - * http://iharder.sourceforge.net/current/java/base64/ - -This product contains a modified version of 'JZlib', a re-implementation of -zlib in pure Java, which can be obtained at: - - * LICENSE: - * license/LICENSE.jzlib.txt (BSD Style License) - * HOMEPAGE: - * http://www.jcraft.com/jzlib/ - -This product optionally depends on 'Protocol Buffers', Google's data -interchange format, which can be obtained at: - - * LICENSE: - * license/LICENSE.protobuf.txt (New BSD License) - * HOMEPAGE: - * http://code.google.com/p/protobuf/ - -This product optionally depends on 'SLF4J', a simple logging facade for Java, -which can be obtained at: - - * LICENSE: - * license/LICENSE.slf4j.txt (MIT License) - * HOMEPAGE: - * http://www.slf4j.org/ - -This product optionally depends on 'Apache Commons Logging', a logging -framework, which can be obtained at: - - * LICENSE: - * license/LICENSE.commons-logging.txt (Apache License 2.0) - * HOMEPAGE: - * http://commons.apache.org/logging/ - -This product optionally depends on 'Apache Log4J', a logging framework, -which can be obtained at: - - * LICENSE: - * license/LICENSE.log4j.txt (Apache License 2.0) - * HOMEPAGE: - * http://logging.apache.org/log4j/ - -This product optionally depends on 'JBoss Logging', a logging framework, -which can be obtained at: - - * LICENSE: - * license/LICENSE.jboss-logging.txt (GNU LGPL 2.1) - * HOMEPAGE: - * http://anonsvn.jboss.org/repos/common/common-logging-spi/ - -This product optionally depends on 'Apache Felix', an open source OSGi -framework implementation, which can be obtained at: - - * LICENSE: - * license/LICENSE.felix.txt (Apache License 2.0) - * HOMEPAGE: - * http://felix.apache.org/ - -This product optionally depends on 'Webbit', a Java event based -WebSocket and HTTP server: - - * LICENSE: - * license/LICENSE.webbit.txt (BSD License) - * HOMEPAGE: - * https://github.com/joewalnes/webbit - -# Jackson JSON processor - -Jackson is a high-performance, Free/Open Source JSON processing library. -It was originally written by Tatu Saloranta (tatu.saloranta@iki.fi), and has -been in development since 2007. -It is currently developed by a community of developers, as well as supported -commercially by FasterXML.com. - -Jackson core and extension components may be licensed under different licenses. -To find the details that apply to this artifact see the accompanying LICENSE file. -For more information, including possible other licensing options, contact -FasterXML.com (http://fasterxml.com). - -## Credits - -A list of contributors may be found from CREDITS file, which is included -in some artifacts (usually source distributions); but is always available -from the source code management (SCM) system project uses. - -Jackson core and extension components may licensed under different licenses. -To find the details that apply to this artifact see the accompanying LICENSE file. -For more information, including possible other licensing options, contact -FasterXML.com (http://fasterxml.com). - -mesos -Copyright 2014 The Apache Software Foundation - -Apache Thrift -Copyright 2006-2010 The Apache Software Foundation. - - Apache Ant - Copyright 1999-2013 The Apache Software Foundation - - The task is based on code Copyright (c) 2002, Landmark - Graphics Corp that has been kindly donated to the Apache Software - Foundation. - -Apache Commons IO -Copyright 2002-2012 The Apache Software Foundation - -Apache Commons Math -Copyright 2001-2013 The Apache Software Foundation - -=============================================================================== - -The inverse error function implementation in the Erf class is based on CUDA -code developed by Mike Giles, Oxford-Man Institute of Quantitative Finance, -and published in GPU Computing Gems, volume 2, 2010. -=============================================================================== - -The BracketFinder (package org.apache.commons.math3.optimization.univariate) -and PowellOptimizer (package org.apache.commons.math3.optimization.general) -classes are based on the Python code in module "optimize.py" (version 0.5) -developed by Travis E. Oliphant for the SciPy library (http://www.scipy.org/) -Copyright © 2003-2009 SciPy Developers. -=============================================================================== - -The LinearConstraint, LinearObjectiveFunction, LinearOptimizer, -RelationShip, SimplexSolver and SimplexTableau classes in package -org.apache.commons.math3.optimization.linear include software developed by -Benjamin McCann (http://www.benmccann.com) and distributed with -the following copyright: Copyright 2009 Google Inc. -=============================================================================== - -This product includes software developed by the -University of Chicago, as Operator of Argonne National -Laboratory. -The LevenbergMarquardtOptimizer class in package -org.apache.commons.math3.optimization.general includes software -translated from the lmder, lmpar and qrsolv Fortran routines -from the Minpack package -Minpack Copyright Notice (1999) University of Chicago. All rights reserved -=============================================================================== - -The GraggBulirschStoerIntegrator class in package -org.apache.commons.math3.ode.nonstiff includes software translated -from the odex Fortran routine developed by E. Hairer and G. Wanner. -Original source copyright: -Copyright (c) 2004, Ernst Hairer -=============================================================================== - -The EigenDecompositionImpl class in package -org.apache.commons.math3.linear includes software translated -from some LAPACK Fortran routines. Original source copyright: -Copyright (c) 1992-2008 The University of Tennessee. All rights reserved. -=============================================================================== - -The MersenneTwister class in package org.apache.commons.math3.random -includes software translated from the 2002-01-26 version of -the Mersenne-Twister generator written in C by Makoto Matsumoto and Takuji -Nishimura. Original source copyright: -Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, -All rights reserved -=============================================================================== - -The LocalizedFormatsTest class in the unit tests is an adapted version of -the OrekitMessagesTest class from the orekit library distributed under the -terms of the Apache 2 licence. Original source copyright: -Copyright 2010 CS Systèmes d'Information -=============================================================================== - -The HermiteInterpolator class and its corresponding test have been imported from -the orekit library distributed under the terms of the Apache 2 licence. Original -source copyright: -Copyright 2010-2012 CS Systèmes d'Information -=============================================================================== - -The creation of the package "o.a.c.m.analysis.integration.gauss" was inspired -by an original code donated by Sébastien Brisard. -=============================================================================== - -The complete text of licenses and disclaimers associated with the the original -sources enumerated above at the time of code translation are in the LICENSE.txt -file. - -This product currently only contains code developed by authors -of specific components, as identified by the source code files; -if such notes are missing files have been created by -Tatu Saloranta. - -For additional credits (generally to people who reported problems) -see CREDITS file. - -Apache Commons Lang -Copyright 2001-2011 The Apache Software Foundation - -Apache Commons Compress -Copyright 2002-2012 The Apache Software Foundation - -Apache Commons CLI -Copyright 2001-2009 The Apache Software Foundation - -Google Guice - Extensions - Servlet -Copyright 2006-2011 Google, Inc. - -Google Guice - Core Library -Copyright 2006-2011 Google, Inc. - -Apache Jakarta HttpClient -Copyright 1999-2007 The Apache Software Foundation - -Apache Hive -Copyright 2008-2013 The Apache Software Foundation - -This product includes software developed by The Apache Software -Foundation (http://www.apache.org/). - -This product includes software developed by The JDBM Project -(http://jdbm.sourceforge.net/). - -This product includes/uses ANTLR (http://www.antlr.org/), -Copyright (c) 2003-2011, Terrence Parr. - -This product includes/uses StringTemplate (http://www.stringtemplate.org/), -Copyright (c) 2011, Terrence Parr. - -This product includes/uses ASM (http://asm.ow2.org/), -Copyright (c) 2000-2007 INRIA, France Telecom. - -This product includes/uses JLine (http://jline.sourceforge.net/), -Copyright (c) 2002-2006, Marc Prud'hommeaux . - -This product includes/uses SQLLine (http://sqlline.sourceforge.net), -Copyright (c) 2002, 2003, 2004, 2005 Marc Prud'hommeaux . - -This product includes/uses SLF4J (http://www.slf4j.org/), -Copyright (c) 2004-2010 QOS.ch - -This product includes/uses Bootstrap (http://twitter.github.com/bootstrap/), -Copyright (c) 2012 Twitter, Inc. - -This product includes/uses Glyphicons (http://glyphicons.com/), -Copyright (c) 2010 - 2012 Jan Kovarík - -This product includes DataNucleus (http://www.datanucleus.org/) -Copyright 2008-2008 DataNucleus - -This product includes Guava (http://code.google.com/p/guava-libraries/) -Copyright (C) 2006 Google Inc. - -This product includes JavaEWAH (http://code.google.com/p/javaewah/) -Copyright (C) 2011 Google Inc. - -Apache Commons Pool -Copyright 1999-2009 The Apache Software Foundation - -This product includes/uses Kubernetes & OpenShift 3 Java Client (https://github.com/fabric8io/kubernetes-client) -Copyright (C) 2015 Red Hat, Inc. - -This product includes/uses OkHttp (https://github.com/square/okhttp) -Copyright (C) 2012 The Android Open Source Project - -========================================================================= -== NOTICE file corresponding to section 4(d) of the Apache License, == -== Version 2.0, in this case for the DataNucleus distribution. == -========================================================================= - -=================================================================== -This product includes software developed by many individuals, -including the following: -=================================================================== -Erik Bengtson -Andy Jefferson - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== -Joerg von Frantzius -Thomas Marti -Barry Haddow -Marco Schulze -Ralph Ullrich -David Ezzio -Brendan de Beer -David Eaves -Martin Taal -Tony Lai -Roland Szabo -Marcus Mennemeier -Xuan Baldauf -Eric Sultan - -=================================================================== -This product also includes software developed by the TJDO project -(http://tjdo.sourceforge.net/). -=================================================================== - -=================================================================== -This product includes software developed by many individuals, -including the following: -=================================================================== -Andy Jefferson -Erik Bengtson -Joerg von Frantzius -Marco Schulze - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== -Barry Haddow -Ralph Ullrich -David Ezzio -Brendan de Beer -David Eaves -Martin Taal -Tony Lai -Roland Szabo -Anton Troshin (Timesten) - -=================================================================== -This product also includes software developed by the Apache Commons project -(http://commons.apache.org/). -=================================================================== - -Apache Java Data Objects (JDO) -Copyright 2005-2006 The Apache Software Foundation - -========================================================================= -== NOTICE file corresponding to section 4(d) of the Apache License, == -== Version 2.0, in this case for the Apache Derby distribution. == -========================================================================= - -Apache Derby -Copyright 2004-2008 The Apache Software Foundation - -Portions of Derby were originally developed by -International Business Machines Corporation and are -licensed to the Apache Software Foundation under the -"Software Grant and Corporate Contribution License Agreement", -informally known as the "Derby CLA". -The following copyright notice(s) were affixed to portions of the code -with which this file is now or was at one time distributed -and are placed here unaltered. - -(C) Copyright 1997,2004 International Business Machines Corporation. All rights reserved. - -(C) Copyright IBM Corp. 2003. - -The portion of the functionTests under 'nist' was originally -developed by the National Institute of Standards and Technology (NIST), -an agency of the United States Department of Commerce, and adapted by -International Business Machines Corporation in accordance with the NIST -Software Acknowledgment and Redistribution document at -http://www.itl.nist.gov/div897/ctg/sql_form.htm - -Apache Commons Collections -Copyright 2001-2008 The Apache Software Foundation - -Apache Commons Configuration -Copyright 2001-2008 The Apache Software Foundation - -Apache Jakarta Commons Digester -Copyright 2001-2006 The Apache Software Foundation - -Apache Commons BeanUtils -Copyright 2000-2008 The Apache Software Foundation - -Apache Avro Mapred API -Copyright 2009-2013 The Apache Software Foundation - -Apache Avro IPC -Copyright 2009-2013 The Apache Software Foundation - - -Vis.js -Copyright 2010-2015 Almende B.V. - -Vis.js is dual licensed under both - - * The Apache 2.0 License - http://www.apache.org/licenses/LICENSE-2.0 - - and - - * The MIT License - http://opensource.org/licenses/MIT - -Vis.js may be distributed under either license. - - -Vis.js uses and redistributes the following third-party libraries: - -- component-emitter - https://github.com/component/emitter - The MIT License - -- hammer.js - http://hammerjs.github.io/ - The MIT License - -- moment.js - http://momentjs.com/ - The MIT License - -- keycharm - https://github.com/AlexDM0/keycharm - The MIT License - -=============================================================================== - -The CSS style for the navigation sidebar of the documentation was originally -submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project -is distributed under the 3-Clause BSD license. -=============================================================================== - -For CSV functionality: - -/* - * Copyright 2014 Databricks - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Copyright 2015 Ayasdi Inc - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -=============================================================================== -For dev/sparktestsupport/toposort.py: - -Copyright 2014 True Blade Systems, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/NOTICE-binary b/NOTICE-binary new file mode 100644 index 0000000000000..d56f99bdb55a6 --- /dev/null +++ b/NOTICE-binary @@ -0,0 +1,1170 @@ +Apache Spark +Copyright 2014 and onwards The Apache Software Foundation. + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +// ------------------------------------------------------------------ +// NOTICE file corresponding to the section 4d of The Apache License, +// Version 2.0, in this case for +// ------------------------------------------------------------------ + +Hive Beeline +Copyright 2016 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +Apache Avro +Copyright 2009-2014 The Apache Software Foundation + +This product currently only contains code developed by authors +of specific components, as identified by the source code files; +if such notes are missing files have been created by +Tatu Saloranta. + +For additional credits (generally to people who reported problems) +see CREDITS file. + +Apache Commons Compress +Copyright 2002-2012 The Apache Software Foundation + +This product includes software developed by +The Apache Software Foundation (http://www.apache.org/). + +Apache Avro Mapred API +Copyright 2009-2014 The Apache Software Foundation + +Apache Avro IPC +Copyright 2009-2014 The Apache Software Foundation + +Objenesis +Copyright 2006-2013 Joe Walnes, Henri Tremblay, Leonardo Mesquita + +Apache XBean :: ASM 5 shaded (repackaged) +Copyright 2005-2015 The Apache Software Foundation + +-------------------------------------- + +This product includes software developed at +OW2 Consortium (http://asm.ow2.org/) + +This product includes software developed by The Apache Software +Foundation (http://www.apache.org/). + +The binary distribution of this product bundles binaries of +org.iq80.leveldb:leveldb-api (https://github.com/dain/leveldb), which has the +following notices: +* Copyright 2011 Dain Sundstrom +* Copyright 2011 FuseSource Corp. http://fusesource.com + +The binary distribution of this product bundles binaries of +org.fusesource.hawtjni:hawtjni-runtime (https://github.com/fusesource/hawtjni), +which has the following notices: +* This product includes software developed by FuseSource Corp. + http://fusesource.com +* This product includes software developed at + Progress Software Corporation and/or its subsidiaries or affiliates. +* This product includes software developed by IBM Corporation and others. + +The binary distribution of this product bundles binaries of +Gson 2.2.4, +which has the following notices: + + The Netty Project + ================= + +Please visit the Netty web site for more information: + + * http://netty.io/ + +Copyright 2014 The Netty Project + +The Netty Project licenses this file to you under the Apache License, +version 2.0 (the "License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at: + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +License for the specific language governing permissions and limitations +under the License. + +Also, please refer to each LICENSE..txt file, which is located in +the 'license' directory of the distribution file, for the license terms of the +components that this product depends on. + +------------------------------------------------------------------------------- +This product contains the extensions to Java Collections Framework which has +been derived from the works by JSR-166 EG, Doug Lea, and Jason T. Greene: + + * LICENSE: + * license/LICENSE.jsr166y.txt (Public Domain) + * HOMEPAGE: + * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ + * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ + +This product contains a modified version of Robert Harder's Public Domain +Base64 Encoder and Decoder, which can be obtained at: + + * LICENSE: + * license/LICENSE.base64.txt (Public Domain) + * HOMEPAGE: + * http://iharder.sourceforge.net/current/java/base64/ + +This product contains a modified portion of 'Webbit', an event based +WebSocket and HTTP server, which can be obtained at: + + * LICENSE: + * license/LICENSE.webbit.txt (BSD License) + * HOMEPAGE: + * https://github.com/joewalnes/webbit + +This product contains a modified portion of 'SLF4J', a simple logging +facade for Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.slf4j.txt (MIT License) + * HOMEPAGE: + * http://www.slf4j.org/ + +This product contains a modified portion of 'ArrayDeque', written by Josh +Bloch of Google, Inc: + + * LICENSE: + * license/LICENSE.deque.txt (Public Domain) + +This product contains a modified portion of 'Apache Harmony', an open source +Java SE, which can be obtained at: + + * LICENSE: + * license/LICENSE.harmony.txt (Apache License 2.0) + * HOMEPAGE: + * http://archive.apache.org/dist/harmony/ + +This product contains a modified version of Roland Kuhn's ASL2 +AbstractNodeQueue, which is based on Dmitriy Vyukov's non-intrusive MPSC queue. +It can be obtained at: + + * LICENSE: + * license/LICENSE.abstractnodequeue.txt (Public Domain) + * HOMEPAGE: + * https://github.com/akka/akka/blob/wip-2.2.3-for-scala-2.11/akka-actor/src/main/java/akka/dispatch/AbstractNodeQueue.java + +This product contains a modified portion of 'jbzip2', a Java bzip2 compression +and decompression library written by Matthew J. Francis. It can be obtained at: + + * LICENSE: + * license/LICENSE.jbzip2.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jbzip2/ + +This product contains a modified portion of 'libdivsufsort', a C API library to construct +the suffix array and the Burrows-Wheeler transformed string for any input string of +a constant-size alphabet written by Yuta Mori. It can be obtained at: + + * LICENSE: + * license/LICENSE.libdivsufsort.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/libdivsufsort/ + +This product contains a modified portion of Nitsan Wakart's 'JCTools', Java Concurrency Tools for the JVM, + which can be obtained at: + + * LICENSE: + * license/LICENSE.jctools.txt (ASL2 License) + * HOMEPAGE: + * https://github.com/JCTools/JCTools + +This product optionally depends on 'JZlib', a re-implementation of zlib in +pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product optionally depends on 'Compress-LZF', a Java library for encoding and +decoding data in LZF format, written by Tatu Saloranta. It can be obtained at: + + * LICENSE: + * license/LICENSE.compress-lzf.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/ning/compress + +This product optionally depends on 'lz4', a LZ4 Java compression +and decompression library written by Adrien Grand. It can be obtained at: + + * LICENSE: + * license/LICENSE.lz4.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jpountz/lz4-java + +This product optionally depends on 'lzma-java', a LZMA Java compression +and decompression library, which can be obtained at: + + * LICENSE: + * license/LICENSE.lzma-java.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jponge/lzma-java + +This product contains a modified portion of 'jfastlz', a Java port of FastLZ compression +and decompression library written by William Kinney. It can be obtained at: + + * LICENSE: + * license/LICENSE.jfastlz.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jfastlz/ + +This product contains a modified portion of and optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + + * LICENSE: + * license/LICENSE.protobuf.txt (New BSD License) + * HOMEPAGE: + * http://code.google.com/p/protobuf/ + +This product optionally depends on 'Bouncy Castle Crypto APIs' to generate +a temporary self-signed X.509 certificate when the JVM does not provide the +equivalent functionality. It can be obtained at: + + * LICENSE: + * license/LICENSE.bouncycastle.txt (MIT License) + * HOMEPAGE: + * http://www.bouncycastle.org/ + +This product optionally depends on 'Snappy', a compression library produced +by Google Inc, which can be obtained at: + + * LICENSE: + * license/LICENSE.snappy.txt (New BSD License) + * HOMEPAGE: + * http://code.google.com/p/snappy/ + +This product optionally depends on 'JBoss Marshalling', an alternative Java +serialization API, which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-marshalling.txt (GNU LGPL 2.1) + * HOMEPAGE: + * http://www.jboss.org/jbossmarshalling + +This product optionally depends on 'Caliper', Google's micro- +benchmarking framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.caliper.txt (Apache License 2.0) + * HOMEPAGE: + * http://code.google.com/p/caliper/ + +This product optionally depends on 'Apache Commons Logging', a logging +framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-logging.txt (Apache License 2.0) + * HOMEPAGE: + * http://commons.apache.org/logging/ + +This product optionally depends on 'Apache Log4J', a logging framework, which +can be obtained at: + + * LICENSE: + * license/LICENSE.log4j.txt (Apache License 2.0) + * HOMEPAGE: + * http://logging.apache.org/log4j/ + +This product optionally depends on 'Aalto XML', an ultra-high performance +non-blocking XML processor, which can be obtained at: + + * LICENSE: + * license/LICENSE.aalto-xml.txt (Apache License 2.0) + * HOMEPAGE: + * http://wiki.fasterxml.com/AaltoHome + +This product contains a modified version of 'HPACK', a Java implementation of +the HTTP/2 HPACK algorithm written by Twitter. It can be obtained at: + + * LICENSE: + * license/LICENSE.hpack.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/twitter/hpack + +This product contains a modified portion of 'Apache Commons Lang', a Java library +provides utilities for the java.lang API, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-lang.txt (Apache License 2.0) + * HOMEPAGE: + * https://commons.apache.org/proper/commons-lang/ + +The binary distribution of this product bundles binaries of +Commons Codec 1.4, +which has the following notices: + * src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.javacontains test data from http://aspell.net/test/orig/batch0.tab.Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org) + =============================================================================== + The content of package org.apache.commons.codec.language.bm has been translated + from the original php source code available at http://stevemorse.org/phoneticinfo.htm + with permission from the original authors. + Original source copyright:Copyright (c) 2008 Alexander Beider & Stephen P. Morse. + +The binary distribution of this product bundles binaries of +Commons Lang 2.6, +which has the following notices: + * This product includes software from the Spring Framework,under the Apache License 2.0 (see: StringUtils.containsWhitespace()) + +The binary distribution of this product bundles binaries of +Apache Log4j 1.2.17, +which has the following notices: + * ResolverUtil.java + Copyright 2005-2006 Tim Fennell + Dumbster SMTP test server + Copyright 2004 Jason Paul Kitchen + TypeUtil.java + Copyright 2002-2012 Ramnivas Laddad, Juergen Hoeller, Chris Beams + +The binary distribution of this product bundles binaries of +Jetty 6.1.26, +which has the following notices: + * ============================================================== + Jetty Web Container + Copyright 1995-2016 Mort Bay Consulting Pty Ltd. + ============================================================== + + The Jetty Web Container is Copyright Mort Bay Consulting Pty Ltd + unless otherwise noted. + + Jetty is dual licensed under both + + * The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0.html + + and + + * The Eclipse Public 1.0 License + http://www.eclipse.org/legal/epl-v10.html + + Jetty may be distributed under either license. + + ------ + Eclipse + + The following artifacts are EPL. + * org.eclipse.jetty.orbit:org.eclipse.jdt.core + + The following artifacts are EPL and ASL2. + * org.eclipse.jetty.orbit:javax.security.auth.message + + The following artifacts are EPL and CDDL 1.0. + * org.eclipse.jetty.orbit:javax.mail.glassfish + + ------ + Oracle + + The following artifacts are CDDL + GPLv2 with classpath exception. + https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + + * javax.servlet:javax.servlet-api + * javax.annotation:javax.annotation-api + * javax.transaction:javax.transaction-api + * javax.websocket:javax.websocket-api + + ------ + Oracle OpenJDK + + If ALPN is used to negotiate HTTP/2 connections, then the following + artifacts may be included in the distribution or downloaded when ALPN + module is selected. + + * java.sun.security.ssl + + These artifacts replace/modify OpenJDK classes. The modififications + are hosted at github and both modified and original are under GPL v2 with + classpath exceptions. + http://openjdk.java.net/legal/gplv2+ce.html + + ------ + OW2 + + The following artifacts are licensed by the OW2 Foundation according to the + terms of http://asm.ow2.org/license.html + + org.ow2.asm:asm-commons + org.ow2.asm:asm + + ------ + Apache + + The following artifacts are ASL2 licensed. + + org.apache.taglibs:taglibs-standard-spec + org.apache.taglibs:taglibs-standard-impl + + ------ + MortBay + + The following artifacts are ASL2 licensed. Based on selected classes from + following Apache Tomcat jars, all ASL2 licensed. + + org.mortbay.jasper:apache-jsp + org.apache.tomcat:tomcat-jasper + org.apache.tomcat:tomcat-juli + org.apache.tomcat:tomcat-jsp-api + org.apache.tomcat:tomcat-el-api + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-api + org.apache.tomcat:tomcat-util-scan + org.apache.tomcat:tomcat-util + + org.mortbay.jasper:apache-el + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-el-api + + ------ + Mortbay + + The following artifacts are CDDL + GPLv2 with classpath exception. + + https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + + org.eclipse.jetty.toolchain:jetty-schemas + + ------ + Assorted + + The UnixCrypt.java code implements the one way cryptography used by + Unix systems for simple password protection. Copyright 1996 Aki Yoshida, + modified April 2001 by Iris Van den Broeke, Daniel Deville. + Permission to use, copy, modify and distribute UnixCrypt + for non-commercial or commercial purposes and without fee is + granted provided that the copyright notice appears in all copies./ + +The binary distribution of this product bundles binaries of +Snappy for Java 1.0.4.1, +which has the following notices: + * This product includes software developed by Google + Snappy: http://code.google.com/p/snappy/ (New BSD License) + + This product includes software developed by Apache + PureJavaCrc32C from apache-hadoop-common http://hadoop.apache.org/ + (Apache 2.0 license) + + This library containd statically linked libstdc++. This inclusion is allowed by + "GCC RUntime Library Exception" + http://gcc.gnu.org/onlinedocs/libstdc++/manual/license.html + + == Contributors == + * Tatu Saloranta + * Providing benchmark suite + * Alec Wysoker + * Performance and memory usage improvement + +The binary distribution of this product bundles binaries of +Xerces2 Java Parser 2.9.1, +which has the following notices: + * ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the Apache Xerces Java distribution. == + ========================================================================= + + Apache Xerces Java + Copyright 1999-2007 The Apache Software Foundation + + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + + Portions of this software were originally based on the following: + - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. + - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. + - voluntary contributions made by Paul Eng on behalf of the + Apache Software Foundation that were originally developed at iClick, Inc., + software copyright (c) 1999. + +Apache Commons Collections +Copyright 2001-2015 The Apache Software Foundation + +Apache Commons Configuration +Copyright 2001-2008 The Apache Software Foundation + +Apache Jakarta Commons Digester +Copyright 2001-2006 The Apache Software Foundation + +Apache Commons BeanUtils +Copyright 2000-2008 The Apache Software Foundation + +ApacheDS Protocol Kerberos Codec +Copyright 2003-2013 The Apache Software Foundation + +ApacheDS I18n +Copyright 2003-2013 The Apache Software Foundation + +Apache Directory API ASN.1 API +Copyright 2003-2013 The Apache Software Foundation + +Apache Directory LDAP API Utilities +Copyright 2003-2013 The Apache Software Foundation + +Curator Client +Copyright 2011-2015 The Apache Software Foundation + +htrace-core +Copyright 2015 The Apache Software Foundation + + ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the Apache Xerces Java distribution. == + ========================================================================= + + Portions of this software were originally based on the following: + - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. + - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. + - voluntary contributions made by Paul Eng on behalf of the + Apache Software Foundation that were originally developed at iClick, Inc., + software copyright (c) 1999. + +# Jackson JSON processor + +Jackson is a high-performance, Free/Open Source JSON processing library. +It was originally written by Tatu Saloranta (tatu.saloranta@iki.fi), and has +been in development since 2007. +It is currently developed by a community of developers, as well as supported +commercially by FasterXML.com. + +## Licensing + +Jackson core and extension components may licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +## Credits + +A list of contributors may be found from CREDITS file, which is included +in some artifacts (usually source distributions); but is always available +from the source code management (SCM) system project uses. + +Apache HttpCore +Copyright 2005-2017 The Apache Software Foundation + +Curator Recipes +Copyright 2011-2015 The Apache Software Foundation + +Curator Framework +Copyright 2011-2015 The Apache Software Foundation + +Apache Commons Lang +Copyright 2001-2016 The Apache Software Foundation + +This product includes software from the Spring Framework, +under the Apache License 2.0 (see: StringUtils.containsWhitespace()) + +Apache Commons Math +Copyright 2001-2015 The Apache Software Foundation + +This product includes software developed for Orekit by +CS Systèmes d'Information (http://www.c-s.fr/) +Copyright 2010-2012 CS Systèmes d'Information + +Apache log4j +Copyright 2007 The Apache Software Foundation + +# Compress LZF + +This library contains efficient implementation of LZF compression format, +as well as additional helper classes that build on JDK-provided gzip (deflat) +codec. + +Library is licensed under Apache License 2.0, as per accompanying LICENSE file. + +## Credit + +Library has been written by Tatu Saloranta (tatu.saloranta@iki.fi). +It was started at Ning, inc., as an official Open Source process used by +platform backend, but after initial versions has been developed outside of +Ning by supporting community. + +Other contributors include: + +* Jon Hartlaub (first versions of streaming reader/writer; unit tests) +* Cedrik Lime: parallel LZF implementation + +Various community members have contributed bug reports, and suggested minor +fixes; these can be found from file "VERSION.txt" in SCM. + +Apache Commons Net +Copyright 2001-2012 The Apache Software Foundation + +Copyright 2011 The Netty Project + +http://www.apache.org/licenses/LICENSE-2.0 + +This product contains a modified version of 'JZlib', a re-implementation of +zlib in pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD Style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product contains a modified version of 'Webbit', a Java event based +WebSocket and HTTP server: + +This product optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + +This product optionally depends on 'SLF4J', a simple logging facade for Java, +which can be obtained at: + +This product optionally depends on 'Apache Log4J', a logging framework, +which can be obtained at: + +This product optionally depends on 'JBoss Logging', a logging framework, +which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-logging.txt (GNU LGPL 2.1) + * HOMEPAGE: + * http://anonsvn.jboss.org/repos/common/common-logging-spi/ + +This product optionally depends on 'Apache Felix', an open source OSGi +framework implementation, which can be obtained at: + + * LICENSE: + * license/LICENSE.felix.txt (Apache License 2.0) + * HOMEPAGE: + * http://felix.apache.org/ + +Jackson core and extension components may be licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +Apache Ivy (TM) +Copyright 2007-2014 The Apache Software Foundation + +Portions of Ivy were originally developed at +Jayasoft SARL (http://www.jayasoft.fr/) +and are licensed to the Apache Software Foundation under the +"Software Grant License Agreement" + +SSH and SFTP support is provided by the JCraft JSch package, +which is open source software, available under +the terms of a BSD style license. +The original software and related information is available +at http://www.jcraft.com/jsch/. + + +ORC Core +Copyright 2013-2018 The Apache Software Foundation + +Apache Commons Lang +Copyright 2001-2011 The Apache Software Foundation + +ORC MapReduce +Copyright 2013-2018 The Apache Software Foundation + +Apache Parquet Format +Copyright 2017 The Apache Software Foundation + +Arrow Vectors +Copyright 2017 The Apache Software Foundation + +Arrow Format +Copyright 2017 The Apache Software Foundation + +Arrow Memory +Copyright 2017 The Apache Software Foundation + +Apache Commons CLI +Copyright 2001-2009 The Apache Software Foundation + +Google Guice - Extensions - Servlet +Copyright 2006-2011 Google, Inc. + +Apache Commons IO +Copyright 2002-2012 The Apache Software Foundation + +Google Guice - Core Library +Copyright 2006-2011 Google, Inc. + +mesos +Copyright 2017 The Apache Software Foundation + +Apache Parquet Hadoop Bundle (Incubating) +Copyright 2015 The Apache Software Foundation + +Hive Query Language +Copyright 2016 The Apache Software Foundation + +Apache Extras Companion for log4j 1.2. +Copyright 2007 The Apache Software Foundation + +Hive Metastore +Copyright 2016 The Apache Software Foundation + +Apache Commons Logging +Copyright 2003-2013 The Apache Software Foundation + +========================================================================= +== NOTICE file corresponding to section 4(d) of the Apache License, == +== Version 2.0, in this case for the DataNucleus distribution. == +========================================================================= + +=================================================================== +This product includes software developed by many individuals, +including the following: +=================================================================== +Erik Bengtson +Andy Jefferson + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== + +=================================================================== +This product includes software developed by many individuals, +including the following: +=================================================================== +Andy Jefferson +Erik Bengtson +Joerg von Frantzius +Marco Schulze + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== +Barry Haddow +Ralph Ullrich +David Ezzio +Brendan de Beer +David Eaves +Martin Taal +Tony Lai +Roland Szabo +Anton Troshin (Timesten) + +=================================================================== +This product also includes software developed by the TJDO project +(http://tjdo.sourceforge.net/). +=================================================================== + +=================================================================== +This product also includes software developed by the Apache Commons project +(http://commons.apache.org/). +=================================================================== + +Apache Commons Pool +Copyright 1999-2009 The Apache Software Foundation + +Apache Commons DBCP +Copyright 2001-2010 The Apache Software Foundation + +Apache Java Data Objects (JDO) +Copyright 2005-2006 The Apache Software Foundation + +Apache Jakarta HttpClient +Copyright 1999-2007 The Apache Software Foundation + +Calcite Avatica +Copyright 2012-2015 The Apache Software Foundation + +Calcite Core +Copyright 2012-2015 The Apache Software Foundation + +Calcite Linq4j +Copyright 2012-2015 The Apache Software Foundation + +Apache HttpClient +Copyright 1999-2017 The Apache Software Foundation + +Apache Commons Codec +Copyright 2002-2014 The Apache Software Foundation + +src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.java +contains test data from http://aspell.net/test/orig/batch0.tab. +Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org) + +=============================================================================== + +The content of package org.apache.commons.codec.language.bm has been translated +from the original php source code available at http://stevemorse.org/phoneticinfo.htm +with permission from the original authors. +Original source copyright: +Copyright (c) 2008 Alexander Beider & Stephen P. Morse. + +============================================================================= += NOTICE file corresponding to section 4d of the Apache License Version 2.0 = +============================================================================= +This product includes software developed by +Joda.org (http://www.joda.org/). + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== +Joerg von Frantzius +Thomas Marti +Barry Haddow +Marco Schulze +Ralph Ullrich +David Ezzio +Brendan de Beer +David Eaves +Martin Taal +Tony Lai +Roland Szabo +Marcus Mennemeier +Xuan Baldauf +Eric Sultan + +Apache Thrift +Copyright 2006-2010 The Apache Software Foundation. + +========================================================================= +== NOTICE file corresponding to section 4(d) of the Apache License, +== Version 2.0, in this case for the Apache Derby distribution. +== +== DO NOT EDIT THIS FILE DIRECTLY. IT IS GENERATED +== BY THE buildnotice TARGET IN THE TOP LEVEL build.xml FILE. +== +========================================================================= + +Apache Derby +Copyright 2004-2015 The Apache Software Foundation + +========================================================================= + +Portions of Derby were originally developed by +International Business Machines Corporation and are +licensed to the Apache Software Foundation under the +"Software Grant and Corporate Contribution License Agreement", +informally known as the "Derby CLA". +The following copyright notice(s) were affixed to portions of the code +with which this file is now or was at one time distributed +and are placed here unaltered. + +(C) Copyright 1997,2004 International Business Machines Corporation. All rights reserved. + +(C) Copyright IBM Corp. 2003. + +The portion of the functionTests under 'nist' was originally +developed by the National Institute of Standards and Technology (NIST), +an agency of the United States Department of Commerce, and adapted by +International Business Machines Corporation in accordance with the NIST +Software Acknowledgment and Redistribution document at +http://www.itl.nist.gov/div897/ctg/sql_form.htm + +The JDBC apis for small devices and JDBC3 (under java/stubs/jsr169 and +java/stubs/jdbc3) were produced by trimming sources supplied by the +Apache Harmony project. In addition, the Harmony SerialBlob and +SerialClob implementations are used. The following notice covers the Harmony sources: + +Portions of Harmony were originally developed by +Intel Corporation and are licensed to the Apache Software +Foundation under the "Software Grant and Corporate Contribution +License Agreement", informally known as the "Intel Harmony CLA". + +The Derby build relies on source files supplied by the Apache Felix +project. The following notice covers the Felix files: + + Apache Felix Main + Copyright 2008 The Apache Software Foundation + + I. Included Software + + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + Licensed under the Apache License 2.0. + + This product includes software developed at + The OSGi Alliance (http://www.osgi.org/). + Copyright (c) OSGi Alliance (2000, 2007). + Licensed under the Apache License 2.0. + + This product includes software from http://kxml.sourceforge.net. + Copyright (c) 2002,2003, Stefan Haustein, Oberhausen, Rhld., Germany. + Licensed under BSD License. + + II. Used Software + + This product uses software developed at + The OSGi Alliance (http://www.osgi.org/). + Copyright (c) OSGi Alliance (2000, 2007). + Licensed under the Apache License 2.0. + + III. License Summary + - Apache License 2.0 + - BSD License + +The Derby build relies on jar files supplied by the Apache Lucene +project. The following notice covers the Lucene files: + +Apache Lucene +Copyright 2013 The Apache Software Foundation + +Includes software from other Apache Software Foundation projects, +including, but not limited to: + - Apache Ant + - Apache Jakarta Regexp + - Apache Commons + - Apache Xerces + +ICU4J, (under analysis/icu) is licensed under an MIT styles license +and Copyright (c) 1995-2008 International Business Machines Corporation and others + +Some data files (under analysis/icu/src/data) are derived from Unicode data such +as the Unicode Character Database. See http://unicode.org/copyright.html for more +details. + +Brics Automaton (under core/src/java/org/apache/lucene/util/automaton) is +BSD-licensed, created by Anders Møller. See http://www.brics.dk/automaton/ + +The levenshtein automata tables (under core/src/java/org/apache/lucene/util/automaton) were +automatically generated with the moman/finenight FSA library, created by +Jean-Philippe Barrette-LaPierre. This library is available under an MIT license, +see http://sites.google.com/site/rrettesite/moman and +http://bitbucket.org/jpbarrette/moman/overview/ + +The class org.apache.lucene.util.WeakIdentityMap was derived from +the Apache CXF project and is Apache License 2.0. + +The Google Code Prettify is Apache License 2.0. +See http://code.google.com/p/google-code-prettify/ + +JUnit (junit-4.10) is licensed under the Common Public License v. 1.0 +See http://junit.sourceforge.net/cpl-v10.html + +This product includes code (JaspellTernarySearchTrie) from Java Spelling Checkin +g Package (jaspell): http://jaspell.sourceforge.net/ +License: The BSD License (http://www.opensource.org/licenses/bsd-license.php) + +The snowball stemmers in + analysis/common/src/java/net/sf/snowball +were developed by Martin Porter and Richard Boulton. +The snowball stopword lists in + analysis/common/src/resources/org/apache/lucene/analysis/snowball +were developed by Martin Porter and Richard Boulton. +The full snowball package is available from + http://snowball.tartarus.org/ + +The KStem stemmer in + analysis/common/src/org/apache/lucene/analysis/en +was developed by Bob Krovetz and Sergio Guzman-Lara (CIIR-UMass Amherst) +under the BSD-license. + +The Arabic,Persian,Romanian,Bulgarian, and Hindi analyzers (common) come with a default +stopword list that is BSD-licensed created by Jacques Savoy. These files reside in: +analysis/common/src/resources/org/apache/lucene/analysis/ar/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/fa/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/ro/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/bg/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/hi/stopwords.txt +See http://members.unine.ch/jacques.savoy/clef/index.html. + +The German,Spanish,Finnish,French,Hungarian,Italian,Portuguese,Russian and Swedish light stemmers +(common) are based on BSD-licensed reference implementations created by Jacques Savoy and +Ljiljana Dolamic. These files reside in: +analysis/common/src/java/org/apache/lucene/analysis/de/GermanLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/de/GermanMinimalStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/es/SpanishLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fi/FinnishLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fr/FrenchLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fr/FrenchMinimalStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/hu/HungarianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/it/ItalianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/pt/PortugueseLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/ru/RussianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/sv/SwedishLightStemmer.java + +The Stempel analyzer (stempel) includes BSD-licensed software developed +by the Egothor project http://egothor.sf.net/, created by Leo Galambos, Martin Kvapil, +and Edmond Nolan. + +The Polish analyzer (stempel) comes with a default +stopword list that is BSD-licensed created by the Carrot2 project. The file resides +in stempel/src/resources/org/apache/lucene/analysis/pl/stopwords.txt. +See http://project.carrot2.org/license.html. + +The SmartChineseAnalyzer source code (smartcn) was +provided by Xiaoping Gao and copyright 2009 by www.imdict.net. + +WordBreakTestUnicode_*.java (under modules/analysis/common/src/test/) +is derived from Unicode data such as the Unicode Character Database. +See http://unicode.org/copyright.html for more details. + +The Morfologik analyzer (morfologik) includes BSD-licensed software +developed by Dawid Weiss and Marcin Miłkowski (http://morfologik.blogspot.com/). + +Morfologik uses data from Polish ispell/myspell dictionary +(http://www.sjp.pl/slownik/en/) licenced on the terms of (inter alia) +LGPL and Creative Commons ShareAlike. + +Morfologic includes data from BSD-licensed dictionary of Polish (SGJP) +(http://sgjp.pl/morfeusz/) + +Servlet-api.jar and javax.servlet-*.jar are under the CDDL license, the original +source code for this can be found at http://www.eclipse.org/jetty/downloads.php + +=========================================================================== +Kuromoji Japanese Morphological Analyzer - Apache Lucene Integration +=========================================================================== + +This software includes a binary and/or source version of data from + + mecab-ipadic-2.7.0-20070801 + +which can be obtained from + + http://atilika.com/releases/mecab-ipadic/mecab-ipadic-2.7.0-20070801.tar.gz + +or + + http://jaist.dl.sourceforge.net/project/mecab/mecab-ipadic/2.7.0-20070801/mecab-ipadic-2.7.0-20070801.tar.gz + +=========================================================================== +mecab-ipadic-2.7.0-20070801 Notice +=========================================================================== + +Nara Institute of Science and Technology (NAIST), +the copyright holders, disclaims all warranties with regard to this +software, including all implied warranties of merchantability and +fitness, in no event shall NAIST be liable for +any special, indirect or consequential damages or any damages +whatsoever resulting from loss of use, data or profits, whether in an +action of contract, negligence or other tortuous action, arising out +of or in connection with the use or performance of this software. + +A large portion of the dictionary entries +originate from ICOT Free Software. The following conditions for ICOT +Free Software applies to the current dictionary as well. + +Each User may also freely distribute the Program, whether in its +original form or modified, to any third party or parties, PROVIDED +that the provisions of Section 3 ("NO WARRANTY") will ALWAYS appear +on, or be attached to, the Program, which is distributed substantially +in the same form as set out herein and that such intended +distribution, if actually made, will neither violate or otherwise +contravene any of the laws and regulations of the countries having +jurisdiction over the User or the intended distribution itself. + +NO WARRANTY + +The program was produced on an experimental basis in the course of the +research and development conducted during the project and is provided +to users as so produced on an experimental basis. Accordingly, the +program is provided without any warranty whatsoever, whether express, +implied, statutory or otherwise. The term "warranty" used herein +includes, but is not limited to, any warranty of the quality, +performance, merchantability and fitness for a particular purpose of +the program and the nonexistence of any infringement or violation of +any right of any third party. + +Each user of the program will agree and understand, and be deemed to +have agreed and understood, that there is no warranty whatsoever for +the program and, accordingly, the entire risk arising from or +otherwise connected with the program is assumed by the user. + +Therefore, neither ICOT, the copyright holder, or any other +organization that participated in or was otherwise related to the +development of the program and their respective officials, directors, +officers and other employees shall be held liable for any and all +damages, including, without limitation, general, special, incidental +and consequential damages, arising out of or otherwise in connection +with the use or inability to use the program or any product, material +or result produced or otherwise obtained by using the program, +regardless of whether they have been advised of, or otherwise had +knowledge of, the possibility of such damages at any time during the +project or thereafter. Each user will be deemed to have agreed to the +foregoing by his or her commencement of use of the program. The term +"use" as used herein includes, but is not limited to, the use, +modification, copying and distribution of the program and the +production of secondary products from the program. + +In the case where the program, whether in its original form or +modified, was distributed or delivered to or received by a user from +any person, organization or entity other than ICOT, unless it makes or +grants independently of ICOT any specific warranty to the user in +writing, such person, organization or entity, will also be exempted +from and not be held liable to the user for any such damages as noted +above as far as the program is concerned. + +The Derby build relies on a jar file supplied by the JSON Simple +project, hosted at https://code.google.com/p/json-simple/. +The JSON simple jar file is licensed under the Apache 2.0 License. + +Hive CLI +Copyright 2016 The Apache Software Foundation + +Hive JDBC +Copyright 2016 The Apache Software Foundation + + +Chill is a set of Scala extensions for Kryo. +Copyright 2012 Twitter, Inc. + +Third Party Dependencies: + +Kryo 2.17 +BSD 3-Clause License +http://code.google.com/p/kryo + +Commons-Codec 1.7 +Apache Public License 2.0 +http://hadoop.apache.org + + + +Breeze is distributed under an Apache License V2.0 (See LICENSE) + +=============================================================================== + +Proximal algorithms outlined in Proximal.scala (package breeze.optimize.proximal) +are based on https://github.com/cvxgrp/proximal (see LICENSE for details) and distributed with +Copyright (c) 2014 by Debasish Das (Verizon), all rights reserved. + +=============================================================================== + +QuadraticMinimizer class in package breeze.optimize.proximal is distributed with Copyright (c) +2014, Debasish Das (Verizon), all rights reserved. + +=============================================================================== + +NonlinearMinimizer class in package breeze.optimize.proximal is distributed with Copyright (c) +2015, Debasish Das (Verizon), all rights reserved. + + + ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the distribution of jets3t. == + ========================================================================= + + This product includes software developed by: + + The Apache Software Foundation (http://www.apache.org/). + + The ExoLab Project (http://www.exolab.org/) + + Sun Microsystems (http://www.sun.com/) + + Codehaus (http://castor.codehaus.org) + + Tatu Saloranta (http://wiki.fasterxml.com/TatuSaloranta) + + + +stream-lib +Copyright 2016 AddThis + +This product includes software developed by AddThis. + +This product also includes code adapted from: + +Apache Solr (http://lucene.apache.org/solr/) +Copyright 2014 The Apache Software Foundation + +Apache Mahout (http://mahout.apache.org/) +Copyright 2014 The Apache Software Foundation \ No newline at end of file diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 855eb5bf77f16..f52d785e05cdd 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -13,6 +13,7 @@ Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), License: Apache License (== 2.0) URL: http://www.apache.org/ http://spark.apache.org/ BugReports: http://spark.apache.org/contributing.html +SystemRequirements: Java (== 8) Depends: R (>= 3.0), methods diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 190c50ea10482..adfd3871f3426 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,6 +201,16 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_distinct", + "array_join", + "array_max", + "array_min", + "array_position", + "array_remove", + "array_repeat", + "array_sort", + "arrays_overlap", + "arrays_zip", "asc", "ascii", "asin", @@ -245,6 +255,7 @@ exportMethods("%<=>%", "decode", "dense_rank", "desc", + "element_at", "encode", "endsWith", "exp", @@ -254,6 +265,7 @@ exportMethods("%<=>%", "expr", "factorial", "first", + "flatten", "floor", "format_number", "format_string", @@ -296,6 +308,8 @@ exportMethods("%<=>%", "lower", "lpad", "ltrim", + "map_entries", + "map_from_arrays", "map_keys", "map_values", "max", @@ -346,6 +360,7 @@ exportMethods("%<=>%", "sinh", "size", "skewness", + "slice", "sort_array", "soundex", "spark_partition_id", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a1c9495b0795e..70eb7a874b75c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2297,6 +2297,8 @@ setMethod("rename", setClassUnion("characterOrColumn", c("character", "Column")) +setClassUnion("numericOrColumn", c("numeric", "Column")) + #' Arrange Rows by Variables #' #' Sort a SparkDataFrame by the specified column(s). diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 9d82814211bc5..660f0864403e0 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -19,7 +19,7 @@ # Creates a SparkR client connection object # if one doesn't already exist -connectBackend <- function(hostname, port, timeout) { +connectBackend <- function(hostname, port, timeout, authSecret) { if (exists(".sparkRcon", envir = .sparkREnv)) { if (isOpen(.sparkREnv[[".sparkRCon"]])) { cat("SparkRBackend client connection already exists\n") @@ -29,7 +29,7 @@ connectBackend <- function(hostname, port, timeout) { con <- socketConnection(host = hostname, port = port, server = FALSE, blocking = TRUE, open = "wb", timeout = timeout) - + doServerAuth(con, authSecret) assign(".sparkRCon", con, envir = .sparkREnv) con } @@ -60,6 +60,47 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack combinedArgs } +checkJavaVersion <- function() { + javaBin <- "java" + javaHome <- Sys.getenv("JAVA_HOME") + javaReqs <- utils::packageDescription(utils::packageName(), fields = c("SystemRequirements")) + sparkJavaVersion <- as.numeric(tail(strsplit(javaReqs, "[(=)]")[[1]], n = 1L)) + if (javaHome != "") { + javaBin <- file.path(javaHome, "bin", javaBin) + } + + # If java is missing from PATH, we get an error in Unix and a warning in Windows + javaVersionOut <- tryCatch( + if (is_windows()) { + # See SPARK-24535 + system2(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE) + } else { + launchScript(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE) + }, + error = function(e) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", e) + }, + warning = function(w) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", w) + }) + javaVersionFilter <- Filter( + function(x) { + grepl(" version", x) + }, javaVersionOut) + + javaVersionStr <- strsplit(javaVersionFilter[[1]], "[\"]")[[1L]][2] + # javaVersionStr is of the form 1.8.0_92. + # Extract 8 from it to compare to sparkJavaVersion + javaVersionNum <- as.integer(strsplit(javaVersionStr, "[.]")[[1L]][2]) + if (javaVersionNum != sparkJavaVersion) { + stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", + javaVersionStr)) + } + return(javaVersionNum) +} + launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { sparkSubmitBinName <- determineSparkSubmitBin() if (sparkHome != "") { @@ -67,6 +108,7 @@ launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { } else { sparkSubmitBin <- sparkSubmitBinName } + combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages) cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") invisible(launchScript(sparkSubmitBin, combinedArgs)) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 8ec727dd042bc..3e996a5ba26fc 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -305,6 +305,8 @@ setCheckpointDirSC <- function(sc, dirName) { #' Currently directories are only supported for Hadoop-supported filesystems. #' Refer Hadoop-supported filesystems at \url{https://wiki.apache.org/hadoop/HCFS}. #' +#' Note: A path can be added only once. Subsequent additions of the same path are ignored. +#' #' @rdname spark.addFile #' @param path The path of the file to be added #' @param recursive Whether to add files recursively from the path. Default is FALSE. diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index a90f7d381026b..cb03f1667629f 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -60,14 +60,18 @@ readTypedObject <- function(con, type) { stop(paste("Unsupported type for deserialization", type))) } -readString <- function(con) { - stringLen <- readInt(con) - raw <- readBin(con, raw(), stringLen, endian = "big") +readStringData <- function(con, len) { + raw <- readBin(con, raw(), len, endian = "big") string <- rawToChar(raw) Encoding(string) <- "UTF-8" string } +readString <- function(con) { + stringLen <- readInt(con) + readStringData(con, stringLen) +} + readInt <- function(con) { readBin(con, integer(), n = 1, endian = "big") } diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index a527426b19674..2929a00330c62 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -189,9 +189,17 @@ NULL #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. #' } +#' @param y Column to compute on. +#' @param value A value to compute on. +#' \itemize{ +#' \item \code{array_contains}: a value to be checked if contained in the column. +#' \item \code{array_position}: a value to locate in the given array. +#' \item \code{array_remove}: a value to remove in the given array. +#' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same -#' options as the JSON data source. +#' options as the JSON data source. In \code{arrays_zip}, this contains additional +#' Columns of arrays to be merged. #' @name column_collection_functions #' @rdname column_collection_functions #' @family collection functions @@ -201,14 +209,24 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1))) +#' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1))) +#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1), array_remove(tmp$v1, 21))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) +#' head(select(tmp, slice(tmp$v1, 2L, 2L))) #' head(select(tmp, sort_array(tmp$v1))) #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) -#' head(select(tmp3, map_keys(tmp3$v3))) -#' head(select(tmp3, map_values(tmp3$v3)))} +#' head(select(tmp3, map_entries(tmp3$v3), map_keys(tmp3$v3), map_values(tmp3$v3))) +#' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) +#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) +#' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) +#' head(select(tmp4, arrays_zip(tmp4$v4, tmp4$v5), map_from_arrays(tmp4$v4, tmp4$v5))) +#' head(select(tmp, concat(df$mpg, df$cyl, df$hp))) +#' tmp5 <- mutate(df, v6 = create_array(df$model, df$model)) +#' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))} NULL #' Window functions for Column operations @@ -796,6 +814,8 @@ setMethod("factorial", #' #' The function by default returns the first values it sees. It will return the first non-missing #' value it sees when na.rm is set to true. If all values are missing, then NA is returned. +#' Note: the function is non-deterministic because its results depends on order of rows which +#' may be non-deterministic after a shuffle. #' #' @param na.rm a logical value indicating whether NA values should be stripped #' before the computation proceeds. @@ -939,6 +959,8 @@ setMethod("kurtosis", #' #' The function by default returns the last values it sees. It will return the last non-missing #' value it sees when na.rm is set to true. If all values are missing, then NA is returned. +#' Note: the function is non-deterministic because its results depends on order of rows which +#' may be non-deterministic after a shuffle. #' #' @param x column to compute on. #' @param na.rm a logical value indicating whether NA values should be stripped @@ -1192,6 +1214,7 @@ setMethod("minute", #' 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. #' This is equivalent to the MONOTONICALLY_INCREASING_ID function in SQL. #' The method should be used with no argument. +#' Note: the function is non-deterministic because its result depends on partition IDs. #' #' @rdname column_nonaggregate_functions #' @aliases monotonically_increasing_id monotonically_increasing_id,missing-method @@ -1245,9 +1268,9 @@ setMethod("quarter", }) #' @details -#' \code{reverse}: Reverses the string column and returns it as a new string column. +#' \code{reverse}: Returns a reversed string or an array with reverse order of elements. #' -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @aliases reverse reverse,Column-method #' @note reverse since 1.5.0 setMethod("reverse", @@ -1898,6 +1921,7 @@ setMethod("atan2", signature(y = "Column"), #' @details #' \code{datediff}: Returns the number of days from \code{y} to \code{x}. +#' If \code{y} is later than \code{x} then the result is positive. #' #' @rdname column_datetime_diff_functions #' @aliases datediff datediff,Column-method @@ -1958,6 +1982,9 @@ setMethod("levenshtein", signature(y = "Column"), #' @details #' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. +#' If \code{y} is later than \code{x}, then the result is positive. If \code{y} and \code{x} +#' are on the same day of month, or both are the last day of month, time of day will be ignored. +#' Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits. #' #' @rdname column_datetime_diff_functions #' @aliases months_between months_between,Column-method @@ -2036,20 +2063,10 @@ setMethod("countDistinct", #' @details #' \code{concat}: Concatenates multiple input columns together into a single column. -#' If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. +#' The function works with strings, binary and compatible array columns. #' -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @aliases concat concat,Column-method -#' @examples -#' -#' \dontrun{ -#' # concatenate strings -#' tmp <- mutate(df, s1 = concat(df$Class, df$Sex), -#' s2 = concat(df$Class, df$Sex, df$Age), -#' s3 = concat(df$Class, df$Sex, df$Age, df$Class), -#' s4 = concat_ws("_", df$Class, df$Sex), -#' s5 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) -#' head(tmp)} #' @note concat since 1.5.0 setMethod("concat", signature(x = "Column"), @@ -2390,6 +2407,13 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @param sep separator to use. #' @rdname column_string_functions #' @aliases concat_ws concat_ws,character,Column-method +#' @examples +#' +#' \dontrun{ +#' # concatenate strings +#' tmp <- mutate(df, s1 = concat_ws("_", df$Class, df$Sex), +#' s2 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) +#' head(tmp)} #' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { @@ -2575,6 +2599,7 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' @details #' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) #' samples from U[0.0, 1.0]. +#' Note: the function is non-deterministic in general case. #' #' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. @@ -2603,6 +2628,7 @@ setMethod("rand", signature(seed = "numeric"), #' @details #' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples #' from the standard normal distribution. +#' Note: the function is non-deterministic in general case. #' #' @rdname column_nonaggregate_functions #' @aliases randn randn,missing-method @@ -2975,7 +3001,6 @@ setMethod("row_number", #' \code{array_contains}: Returns null if the array is null, true if the array contains #' the value, and false otherwise. #' -#' @param value a value to be checked if contained in the column #' @rdname column_collection_functions #' @aliases array_contains array_contains,Column-method #' @note array_contains since 1.6.0 @@ -2986,6 +3011,204 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_distinct}: Removes duplicate values from the array. +#' +#' @rdname column_collection_functions +#' @aliases array_distinct array_distinct,Column-method +#' @note array_distinct since 2.4.0 +setMethod("array_distinct", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_distinct", x@jc) + column(jc) + }) + +#' @details +#' \code{array_join}: Concatenates the elements of column using the delimiter. +#' Null values are replaced with nullReplacement if set, otherwise they are ignored. +#' +#' @param delimiter a character string that is used to concatenate the elements of column. +#' @param nullReplacement an optional character string that is used to replace the Null values. +#' @rdname column_collection_functions +#' @aliases array_join array_join,Column-method +#' @note array_join since 2.4.0 +setMethod("array_join", + signature(x = "Column", delimiter = "character"), + function(x, delimiter, nullReplacement = NULL) { + jc <- if (is.null(nullReplacement)) { + callJStatic("org.apache.spark.sql.functions", "array_join", x@jc, delimiter) + } else { + callJStatic("org.apache.spark.sql.functions", "array_join", x@jc, delimiter, + as.character(nullReplacement)) + } + column(jc) + }) + +#' @details +#' \code{array_max}: Returns the maximum value of the array. +#' +#' @rdname column_collection_functions +#' @aliases array_max array_max,Column-method +#' @note array_max since 2.4.0 +setMethod("array_max", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_max", x@jc) + column(jc) + }) + +#' @details +#' \code{array_min}: Returns the minimum value of the array. +#' +#' @rdname column_collection_functions +#' @aliases array_min array_min,Column-method +#' @note array_min since 2.4.0 +setMethod("array_min", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_min", x@jc) + column(jc) + }) + +#' @details +#' \code{array_position}: Locates the position of the first occurrence of the given value +#' in the given array. Returns NA if either of the arguments are NA. +#' Note: The position is not zero based, but 1 based index. Returns 0 if the given +#' value could not be found in the array. +#' +#' @rdname column_collection_functions +#' @aliases array_position array_position,Column-method +#' @note array_position since 2.4.0 +setMethod("array_position", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_position", x@jc, value) + column(jc) + }) + +#' @details +#' \code{array_remove}: Removes all elements that equal to element from the given array. +#' +#' @rdname column_collection_functions +#' @aliases array_remove array_remove,Column-method +#' @note array_remove since 2.4.0 +setMethod("array_remove", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_remove", x@jc, value) + column(jc) + }) + +#' @details +#' \code{array_repeat}: Creates an array containing \code{x} repeated the number of times +#' given by \code{count}. +#' +#' @param count a Column or constant determining the number of repetitions. +#' @rdname column_collection_functions +#' @aliases array_repeat array_repeat,Column,numericOrColumn-method +#' @note array_repeat since 2.4.0 +setMethod("array_repeat", + signature(x = "Column", count = "numericOrColumn"), + function(x, count) { + if (class(count) == "Column") { + count <- count@jc + } else { + count <- as.integer(count) + } + jc <- callJStatic("org.apache.spark.sql.functions", "array_repeat", x@jc, count) + column(jc) + }) + +#' @details +#' \code{array_sort}: Sorts the input array in ascending order. The elements of the input array +#' must be orderable. NA elements will be placed at the end of the returned array. +#' +#' @rdname column_collection_functions +#' @aliases array_sort array_sort,Column-method +#' @note array_sort since 2.4.0 +setMethod("array_sort", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc) + column(jc) + }) + +#' @details +#' \code{arrays_overlap}: Returns true if the input arrays have at least one non-null element in +#' common. If not and both arrays are non-empty and any of them contains a null, it returns null. +#' It returns false otherwise. +#' +#' @rdname column_collection_functions +#' @aliases arrays_overlap arrays_overlap,Column-method +#' @note arrays_overlap since 2.4.0 +setMethod("arrays_overlap", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "arrays_overlap", x@jc, y@jc) + column(jc) + }) + +#' @details +#' \code{arrays_zip}: Returns a merged array of structs in which the N-th struct contains all N-th +#' values of input arrays. +#' +#' @rdname column_collection_functions +#' @aliases arrays_zip arrays_zip,Column-method +#' @note arrays_zip since 2.4.0 +setMethod("arrays_zip", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function(arg) { + stopifnot(class(arg) == "Column") + arg@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "arrays_zip", jcols) + column(jc) + }) + +#' @details +#' \code{flatten}: Creates a single array from an array of arrays. +#' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. +#' +#' @rdname column_collection_functions +#' @aliases flatten flatten,Column-method +#' @note flatten since 2.4.0 +setMethod("flatten", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "flatten", x@jc) + column(jc) + }) + +#' @details +#' \code{map_entries}: Returns an unordered array of all entries in the given map. +#' +#' @rdname column_collection_functions +#' @aliases map_entries map_entries,Column-method +#' @note map_entries since 2.4.0 +setMethod("map_entries", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_entries", x@jc) + column(jc) + }) + +#' @details +#' \code{map_from_arrays}: Creates a new map column. The array in the first column is used for +#' keys. The array in the second column is used for values. All elements in the array for key +#' should not be null. +#' +#' @rdname column_collection_functions +#' @aliases map_from_arrays map_from_arrays,Column-method +#' @note map_from_arrays since 2.4.0 +setMethod("map_from_arrays", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_from_arrays", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' @@ -3012,6 +3235,22 @@ setMethod("map_values", column(jc) }) +#' @details +#' \code{element_at}: Returns element of array at given index in \code{extraction} if +#' \code{x} is array. Returns value for the given key in \code{extraction} if \code{x} is map. +#' Note: The position is not zero based, but 1 based index. +#' +#' @param extraction index to check for in array or key to check for in map +#' @rdname column_collection_functions +#' @aliases element_at element_at,Column-method +#' @note element_at since 2.4.0 +setMethod("element_at", + signature(x = "Column", extraction = "ANY"), + function(x, extraction) { + jc <- callJStatic("org.apache.spark.sql.functions", "element_at", x@jc, extraction) + column(jc) + }) + #' @details #' \code{explode}: Creates a new row for each element in the given array or map column. #' @@ -3039,8 +3278,25 @@ setMethod("size", }) #' @details -#' \code{sort_array}: Sorts the input array in ascending or descending order according -#' to the natural ordering of the array elements. +#' \code{slice}: Returns an array containing all the elements in x from the index start +#' (or starting from the end if start is negative) with the specified length. +#' +#' @rdname column_collection_functions +#' @param start an index indicating the first element occurring in the result. +#' @param length a number of consecutive elements chosen to the result. +#' @aliases slice slice,Column-method +#' @note slice since 2.4.0 +setMethod("slice", + signature(x = "Column"), + function(x, start, length) { + jc <- callJStatic("org.apache.spark.sql.functions", "slice", x@jc, start, length) + column(jc) + }) + +#' @details +#' \code{sort_array}: Sorts the input array in ascending or descending order according to +#' the natural ordering of the array elements. NA elements will be placed at the beginning of +#' the returned array in ascending order or at the end of the returned array in descending order. #' #' @rdname column_collection_functions #' @param asc a logical flag indicating the sorting order. @@ -3109,6 +3365,8 @@ setMethod("create_map", #' @details #' \code{collect_list}: Creates a list of objects with duplicates. +#' Note: the function is non-deterministic because the order of collected results depends +#' on order of rows which may be non-deterministic after a shuffle. #' #' @rdname column_aggregate_functions #' @aliases collect_list collect_list,Column-method @@ -3128,6 +3386,8 @@ setMethod("collect_list", #' @details #' \code{collect_set}: Creates a list of objects with duplicate elements eliminated. +#' Note: the function is non-deterministic because the order of collected results depends +#' on order of rows which may be non-deterministic after a shuffle. #' #' @rdname column_aggregate_functions #' @aliases collect_set collect_set,Column-method diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 974beff1a3d76..4a7210bf1b902 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -624,7 +624,7 @@ setGeneric("summarize", function(x, ...) { standardGeneric("summarize") }) #' @rdname summary setGeneric("summary", function(object, ...) { standardGeneric("summary") }) -setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) +setGeneric("toJSON", function(x, ...) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) @@ -757,6 +757,46 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_distinct", function(x) { standardGeneric("array_distinct") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_join", function(x, delimiter, ...) { standardGeneric("array_join") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_max", function(x) { standardGeneric("array_max") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_min", function(x) { standardGeneric("array_min") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_remove", function(x, value) { standardGeneric("array_remove") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("arrays_zip", function(x, ...) { standardGeneric("arrays_zip") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -801,7 +841,7 @@ setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @rdname column setGeneric("column", function(x) { standardGeneric("column") }) -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @name NULL setGeneric("concat", function(x, ...) { standardGeneric("concat") }) @@ -886,6 +926,10 @@ setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("element_at", function(x, extraction) { standardGeneric("element_at") }) + #' @rdname column_string_functions #' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) @@ -902,6 +946,10 @@ setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @name NULL setGeneric("expr", function(x) { standardGeneric("expr") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("flatten", function(x) { standardGeneric("flatten") }) + #' @rdname column_datetime_diff_functions #' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) @@ -1010,6 +1058,14 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @name NULL setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_entries", function(x) { standardGeneric("map_entries") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_from_arrays", function(x, y) { standardGeneric("map_from_arrays") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) @@ -1110,7 +1166,7 @@ setGeneric("regexp_replace", #' @name NULL setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @name NULL setGeneric("reverse", function(x) { standardGeneric("reverse") }) @@ -1170,6 +1226,10 @@ setGeneric("size", function(x) { standardGeneric("size") }) #' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("slice", function(x, start, length) { standardGeneric("slice") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index a480ac606f10d..d3a9cbae7d808 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -158,11 +158,16 @@ sparkR.sparkContext <- function( " please use the --packages commandline instead", sep = ",")) } backendPort <- existingPort + authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET") + if (nchar(authSecret) == 0) { + stop("Auth secret not provided in environment.") + } } else { path <- tempfile(pattern = "backend_port") submitOps <- getClientModeSparkSubmitOpts( Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), sparkEnvirMap) + invisible(checkJavaVersion()) launchBackend( args = path, sparkHome = sparkHome, @@ -186,16 +191,27 @@ sparkR.sparkContext <- function( monitorPort <- readInt(f) rLibPath <- readString(f) connectionTimeout <- readInt(f) + + # Don't use readString() so that we can provide a useful + # error message if the R and Java versions are mismatched. + authSecretLen <- readInt(f) + if (length(authSecretLen) == 0 || authSecretLen == 0) { + stop("Unexpected EOF in JVM connection data. Mismatched versions?") + } + authSecret <- readStringData(f, authSecretLen) close(f) file.remove(path) if (length(backendPort) == 0 || backendPort == 0 || length(monitorPort) == 0 || monitorPort == 0 || - length(rLibPath) != 1) { + length(rLibPath) != 1 || length(authSecret) == 0) { stop("JVM failed to launch") } - assign(".monitorConn", - socketConnection(port = monitorPort, timeout = connectionTimeout), - envir = .sparkREnv) + + monitorConn <- socketConnection(port = monitorPort, blocking = TRUE, + timeout = connectionTimeout, open = "wb") + doServerAuth(monitorConn, authSecret) + + assign(".monitorConn", monitorConn, envir = .sparkREnv) assign(".backendLaunched", 1, envir = .sparkREnv) if (rLibPath != "") { assign(".libPath", rLibPath, envir = .sparkREnv) @@ -205,7 +221,7 @@ sparkR.sparkContext <- function( .sparkREnv$backendPort <- backendPort tryCatch({ - connectBackend("localhost", backendPort, timeout = connectionTimeout) + connectBackend("localhost", backendPort, timeout = connectionTimeout, authSecret = authSecret) }, error = function(err) { stop("Failed to connect JVM\n") @@ -687,3 +703,17 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) { NULL } } + +# Utility function for sending auth data over a socket and checking the server's reply. +doServerAuth <- function(con, authSecret) { + if (nchar(authSecret) == 0) { + stop("Auth secret not provided.") + } + writeString(con, authSecret) + flush(con) + reply <- readString(con) + if (reply != "ok") { + close(con) + stop("Unexpected reply from server.") + } +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index f1b5ecaa017df..c3501977e64bc 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -746,7 +746,7 @@ varargsToJProperties <- function(...) { props } -launchScript <- function(script, combinedArgs, wait = FALSE) { +launchScript <- function(script, combinedArgs, wait = FALSE, stdout = "", stderr = "") { if (.Platform$OS.type == "windows") { scriptWithArgs <- paste(script, combinedArgs, sep = " ") # on Windows, intern = F seems to mean output to the console. (documentation on this is missing) @@ -756,7 +756,7 @@ launchScript <- function(script, combinedArgs, wait = FALSE) { # stdout = F means discard output # stdout = "" means to its console (default) # Note that the console of this child process might not be the same as the running R process. - system2(script, combinedArgs, stdout = "", wait = wait) + system2(script, combinedArgs, stdout = stdout, wait = wait, stderr = stderr) } } diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R index 823d26f12feee..243f5f0298284 100644 --- a/R/pkg/inst/tests/testthat/test_basic.R +++ b/R/pkg/inst/tests/testthat/test_basic.R @@ -18,6 +18,10 @@ context("basic tests for CRAN") test_that("create DataFrame from list or data.frame", { + tryCatch( checkJavaVersion(), + error = function(e) { skip("error on Java check") }, + warning = function(e) { skip("warning on Java check") } ) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) @@ -50,6 +54,10 @@ test_that("create DataFrame from list or data.frame", { }) test_that("spark.glm and predict", { + tryCatch( checkJavaVersion(), + error = function(e) { skip("error on Java check") }, + warning = function(e) { skip("warning on Java check") } ) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 2e31dc5f728cd..fb9db63b07cd0 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -28,7 +28,9 @@ suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( - port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) + port = port, open = "wb", blocking = TRUE, timeout = connectionTimeout) + +SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET")) # Waits indefinitely for a socket connecion by default. selectTimeout <- NULL diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 00789d815bba8..ba458d2b9ddfb 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -100,9 +100,12 @@ suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( - port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout) + port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) +SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET")) + outputCon <- socketConnection( port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) +SparkR:::doServerAuth(outputCon, Sys.getenv("SPARKR_WORKER_SECRET")) # read the index of the current partition inside the RDD partition <- SparkR:::readInt(inputCon) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 7105469ffc242..adcbbff823a2d 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,24 +1479,125 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains() and sort_array() + # Test array_contains(), array_max(), array_min(), array_position(), element_at() and reverse() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) + result <- collect(select(df, array_max(df[[1]])))[[1]] + expect_equal(result, c(3, 6)) + + result <- collect(select(df, array_min(df[[1]])))[[1]] + expect_equal(result, c(1, 4)) + + result <- collect(select(df, array_position(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 0)) + + result <- collect(select(df, element_at(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 6)) + + result <- collect(select(df, reverse(df[[1]])))[[1]] + expect_equal(result, list(list(3L, 2L, 1L), list(4L, 5L, 6L))) + + df2 <- createDataFrame(list(list("abc"))) + result <- collect(select(df2, reverse(df2[[1]])))[[1]] + expect_equal(result, "cba") + + # Test array_distinct() and array_remove() + df <- createDataFrame(list(list(list(1L, 2L, 3L, 1L, 2L)), list(list(6L, 5L, 5L, 4L, 6L)))) + result <- collect(select(df, array_distinct(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L), list(6L, 5L, 4L))) + + result <- collect(select(df, array_remove(df[[1]], 2L)))[[1]] + expect_equal(result, list(list(1L, 3L, 1L), list(6L, 5L, 5L, 4L, 6L))) + + # Test arrays_zip() + df <- createDataFrame(list(list(list(1L, 2L), list(3L, 4L))), schema = c("c1", "c2")) + result <- collect(select(df, arrays_zip(df[[1]], df[[2]])))[[1]] + expected_entries <- list(listToStruct(list(c1 = 1L, c2 = 3L)), + listToStruct(list(c1 = 2L, c2 = 4L))) + expect_equal(result, list(expected_entries)) + + # Test map_from_arrays() + df <- createDataFrame(list(list(list("x", "y"), list(1, 2))), schema = c("k", "v")) + result <- collect(select(df, map_from_arrays(df$k, df$v)))[[1]] + expected_entries <- list(as.environment(list(x = 1, y = 2))) + expect_equal(result, expected_entries) + + # Test array_repeat() + df <- createDataFrame(list(list("a", 3L), list("b", 2L))) + result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]] + expect_equal(result, list(list("a", "a", "a"), list("b", "b"))) + + result <- collect(select(df, array_repeat(df[[1]], 2L)))[[1]] + expect_equal(result, list(list("a", "a"), list("b", "b"))) + + # Test arrays_overlap() + df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)), + list(list(1L, 2L), list(3L, 4L)), + list(list(1L, NA), list(3L, 4L)))) + result <- collect(select(df, arrays_overlap(df[[1]], df[[2]])))[[1]] + expect_equal(result, c(TRUE, FALSE, NA)) + + # Test array_join() + df <- createDataFrame(list(list(list("Hello", "World!")))) + result <- collect(select(df, array_join(df[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + df2 <- createDataFrame(list(list(list("Hello", NA, "World!")))) + result <- collect(select(df2, array_join(df2[[1]], "#", "Beautiful")))[[1]] + expect_equal(result, "Hello#Beautiful#World!") + result <- collect(select(df2, array_join(df2[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + df3 <- createDataFrame(list(list(list("Hello", NULL, "World!")))) + result <- collect(select(df3, array_join(df3[[1]], "#", "Beautiful")))[[1]] + expect_equal(result, "Hello#Beautiful#World!") + result <- collect(select(df3, array_join(df3[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + + # Test array_sort() and sort_array() + df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) + + result <- collect(select(df, array_sort(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, NA), list(4L, 5L, 6L, NA, NA))) + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] - expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) + expect_equal(result, list(list(3L, 2L, 1L, NA), list(6L, 5L, 4L, NA, NA))) result <- collect(select(df, sort_array(df[[1]])))[[1]] - expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) - - # Test map_keys() and map_values() + expect_equal(result, list(list(NA, 1L, 2L, 3L), list(NA, NA, 4L, 5L, 6L))) + + # Test slice() + df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(4L, 5L)))) + result <- collect(select(df, slice(df[[1]], 2L, 2L)))[[1]] + expect_equal(result, list(list(2L, 3L), list(5L))) + + # Test concat() + df <- createDataFrame(list(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), + list(list(7L, 8L, 9L), list(10L, 11L, 12L)))) + result <- collect(select(df, concat(df[[1]], df[[2]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, 4L, 5L, 6L), list(7L, 8L, 9L, 10L, 11L, 12L))) + + # Test flatten() + df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), + list(list(list(5L, 6L), list(7L, 8L))))) + result <- collect(select(df, flatten(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L))) + + # Test map_entries(), map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) + result <- collect(select(df, map_entries(df$map)))[[1]] + expected_entries <- list(listToStruct(list(key = "x", value = 1)), + listToStruct(list(key = "y", value = 2))) + expect_equal(result, list(expected_entries)) + result <- collect(select(df, map_keys(df$map)))[[1]] expect_equal(result, list(list("x", "y"))) result <- collect(select(df, map_values(df$map)))[[1]] expect_equal(result, list(list(1, 2))) + result <- collect(select(df, element_at(df$map, "y")))[[1]] + expect_equal(result, 2) + # Test that stats::lag is working expect_equal(length(lag(ldeaths, 12)), 72) @@ -2188,8 +2289,8 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { expect_equal(count(where(join(df, df2), df$name == df2$name)), 3) # cartesian join expect_error(tryCatch(count(join(df, df2)), error = function(e) { stop(e) }), - paste0(".*(org.apache.spark.sql.AnalysisException: Detected cartesian product for", - " INNER join between logical plans).*")) + paste0(".*(org.apache.spark.sql.AnalysisException: Detected implicit cartesian", + " product for INNER join between logical plans).*")) joined <- crossJoin(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index d4713de7806a1..68a18ab57b28d 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -590,6 +590,7 @@ summary(model) Predict values on training data ```{r} prediction <- predict(model, training) +head(select(prediction, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Logistic Regression @@ -613,6 +614,7 @@ summary(model) Predict values on training data ```{r} fitted <- predict(model, training) +head(select(fitted, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` Multinomial logistic regression against three classes @@ -807,6 +809,7 @@ df <- createDataFrame(t) dtModel <- spark.decisionTree(df, Survived ~ ., type = "classification", maxDepth = 2) summary(dtModel) predictions <- predict(dtModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Gradient-Boosted Trees @@ -822,6 +825,7 @@ df <- createDataFrame(t) gbtModel <- spark.gbt(df, Survived ~ ., type = "classification", maxDepth = 2, maxIter = 2) summary(gbtModel) predictions <- predict(gbtModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Random Forest @@ -837,6 +841,7 @@ df <- createDataFrame(t) rfModel <- spark.randomForest(df, Survived ~ ., type = "classification", maxDepth = 2, numTrees = 2) summary(rfModel) predictions <- predict(rfModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Bisecting k-Means diff --git a/README.md b/README.md index 1e521a7e7b178..531d330234062 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,8 @@ can be run using: Please see the guidance on how to [run tests for a module, or individual tests](http://spark.apache.org/developer-tools.html#individual-tests). +There is also a Kubernetes integration test, see resource-managers/kubernetes/integration-tests/README.md + ## A Note About Hadoop Versions Spark uses the Hadoop core library to talk to HDFS and other Hadoop-supported diff --git a/assembly/pom.xml b/assembly/pom.xml index a207dae5a74ff..9608c96fd5369 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -254,6 +254,14 @@ spark-hadoop-cloud_${scala.binary.version} ${project.version} + + + org.eclipse.jetty + jetty-util + ${hadoop.deps.scope} + diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index f090240065bf1..a3f1bcffaea57 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -63,16 +63,25 @@ function build { if [ ! -d "$IMG_PATH" ]; then error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." fi - - local DOCKERFILE=${DOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} - - docker build "${BUILD_ARGS[@]}" \ + local BINDING_BUILD_ARGS=( + --build-arg + base_img=$(image_ref spark) + ) + local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} + local PYDOCKERFILE=${PYDOCKERFILE:-"$IMG_PATH/spark/bindings/python/Dockerfile"} + + docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ - -f "$DOCKERFILE" . + -f "$BASEDOCKERFILE" . + + docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ + -t $(image_ref spark-py) \ + -f "$PYDOCKERFILE" . } function push { docker push "$(image_ref spark)" + docker push "$(image_ref spark-py)" } function usage { @@ -86,10 +95,12 @@ Commands: push Push a pre-built image to a registry. Requires a repository address to be provided. Options: - -f file Dockerfile to build. By default builds the Dockerfile shipped with Spark. + -f file Dockerfile to build for JVM based Jobs. By default builds the Dockerfile shipped with Spark. + -p file Dockerfile with Python baked in. By default builds the Dockerfile shipped with Spark. -r repo Repository address. -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. + -n Build docker image with --no-cache Using minikube when building images will do so directly into minikube's Docker daemon. There is no need to push the images into minikube in that case, they'll be automatically @@ -116,14 +127,18 @@ fi REPO= TAG= -DOCKERFILE= -while getopts f:mr:t: option +BASEDOCKERFILE= +PYDOCKERFILE= +NOCACHEARG= +while getopts f:mr:t:n option do case "${option}" in - f) DOCKERFILE=${OPTARG};; + f) BASEDOCKERFILE=${OPTARG};; + p) PYDOCKERFILE=${OPTARG};; r) REPO=${OPTARG};; t) TAG=${OPTARG};; + n) NOCACHEARG="--no-cache";; m) if ! which minikube 1>/dev/null; then error "Cannot find minikube." diff --git a/bin/pyspark b/bin/pyspark index dd286277c1fc1..5d5affb1f97c3 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -25,14 +25,14 @@ source "${SPARK_HOME}"/bin/load-spark-env.sh export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]" # In Spark 2.0, IPYTHON and IPYTHON_OPTS are removed and pyspark fails to launch if either option -# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython +# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython # to use IPython and set PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver # (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython # and executor Python executables. # Fail noisily if removed options are set if [[ -n "$IPYTHON" || -n "$IPYTHON_OPTS" ]]; then - echo "Error in pyspark startup:" + echo "Error in pyspark startup:" echo "IPYTHON and IPYTHON_OPTS are removed in Spark 2.0+. Remove these from the environment and set PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS instead." exit 1 fi @@ -57,7 +57,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 663670f2fddaf..15fa910c277b3 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.6-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.7-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/build/mvn b/build/mvn index efa4f9364ea52..ae4276dbc7e32 100755 --- a/build/mvn +++ b/build/mvn @@ -93,7 +93,7 @@ install_mvn() { install_zinc() { local zinc_path="zinc-0.3.15/bin/zinc" [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 - local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} install_app \ "${TYPESAFE_MIRROR}/zinc/0.3.15" \ @@ -109,7 +109,7 @@ install_scala() { # determine the Scala version used in Spark local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'` local scala_bin="${_DIR}/scala-${scala_version}/bin/scala" - local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} install_app \ "${TYPESAFE_MIRROR}/scala/${scala_version}" \ @@ -154,4 +154,4 @@ export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} echo "Using \`mvn\` from path: $MVN_BIN" 1>&2 # Last, call the `mvn` command as usual -${MVN_BIN} -DzincPort=${ZINC_PORT} "$@" +"${MVN_BIN}" -DzincPort=${ZINC_PORT} "$@" diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java new file mode 100644 index 0000000000000..bd173b653e33e --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +public interface StreamCallbackWithID extends StreamCallback { + String getID(); +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index b0e85bae7c309..f3eb744ff7345 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -22,22 +22,24 @@ import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.server.MessageHandler; import org.apache.spark.network.util.TransportFrameDecoder; /** * An interceptor that is registered with the frame decoder to feed stream data to a * callback. */ -class StreamInterceptor implements TransportFrameDecoder.Interceptor { +public class StreamInterceptor implements TransportFrameDecoder.Interceptor { - private final TransportResponseHandler handler; + private final MessageHandler handler; private final String streamId; private final long byteCount; private final StreamCallback callback; private long bytesRead; - StreamInterceptor( - TransportResponseHandler handler, + public StreamInterceptor( + MessageHandler handler, String streamId, long byteCount, StreamCallback callback) { @@ -50,16 +52,24 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor { @Override public void exceptionCaught(Throwable cause) throws Exception { - handler.deactivateStream(); + deactivateStream(); callback.onFailure(streamId, cause); } @Override public void channelInactive() throws Exception { - handler.deactivateStream(); + deactivateStream(); callback.onFailure(streamId, new ClosedChannelException()); } + private void deactivateStream() { + if (handler instanceof TransportResponseHandler) { + // we only have to do this for TransportResponseHandler as it exposes numOutstandingFetches + // (there is no extra cleanup that needs to happen) + ((TransportResponseHandler) handler).deactivateStream(); + } + } + @Override public boolean handle(ByteBuf buf) throws Exception { int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead); @@ -72,10 +82,10 @@ public boolean handle(ByteBuf buf) throws Exception { RuntimeException re = new IllegalStateException(String.format( "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead)); callback.onFailure(streamId, re); - handler.deactivateStream(); + deactivateStream(); throw re; } else if (bytesRead == byteCount) { - handler.deactivateStream(); + deactivateStream(); callback.onComplete(streamId); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 8f354ad78bbaa..325225dc0ea2c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -32,15 +32,15 @@ import com.google.common.base.Throwables; import com.google.common.util.concurrent.SettableFuture; import io.netty.channel.Channel; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.OneWayMessage; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.StreamRequest; +import org.apache.spark.network.protocol.*; + import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** @@ -133,34 +133,21 @@ public void fetchChunk( long streamId, int chunkIndex, ChunkReceivedCallback callback) { - long startTime = System.currentTimeMillis(); if (logger.isDebugEnabled()) { logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel)); } StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); - handler.addFetchRequest(streamChunkId, callback); - - channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request {} to {} took {} ms", streamChunkId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); + StdChannelListener listener = new StdChannelListener(streamChunkId) { + @Override + void handleFailure(String errorMsg, Throwable cause) { handler.removeFetchRequest(streamChunkId); - channel.close(); - try { - callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } + callback.onFailure(chunkIndex, new IOException(errorMsg, cause)); } - }); + }; + handler.addFetchRequest(streamChunkId, callback); + + channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(listener); } /** @@ -170,7 +157,12 @@ public void fetchChunk( * @param callback Object to call with the stream data. */ public void stream(String streamId, StreamCallback callback) { - long startTime = System.currentTimeMillis(); + StdChannelListener listener = new StdChannelListener(streamId) { + @Override + void handleFailure(String errorMsg, Throwable cause) throws Exception { + callback.onFailure(streamId, new IOException(errorMsg, cause)); + } + }; if (logger.isDebugEnabled()) { logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel)); } @@ -180,25 +172,7 @@ public void stream(String streamId, StreamCallback callback) { // when responses arrive. synchronized (this) { handler.addStreamCallback(streamId, callback); - channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request for {} to {} took {} ms", streamId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - channel.close(); - try { - callback.onFailure(streamId, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - }); + channel.writeAndFlush(new StreamRequest(streamId)).addListener(listener); } } @@ -211,35 +185,44 @@ public void stream(String streamId, StreamCallback callback) { * @return The RPC's id. */ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) { - long startTime = System.currentTimeMillis(); if (logger.isTraceEnabled()) { logger.trace("Sending RPC to {}", getRemoteAddress(channel)); } - long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); + long requestId = requestId(); handler.addRpcRequest(requestId, callback); + RpcChannelListener listener = new RpcChannelListener(requestId, callback); channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))) - .addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request {} to {} took {} ms", requestId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - handler.removeRpcRequest(requestId); - channel.close(); - try { - callback.onFailure(new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - }); + .addListener(listener); + + return requestId; + } + + /** + * Send data to the remote end as a stream. This differs from stream() in that this is a request + * to *send* data to the remote end, not to receive it from the remote. + * + * @param meta meta data associated with the stream, which will be read completely on the + * receiving end before the stream itself. + * @param data this will be streamed to the remote end to allow for transferring large amounts + * of data without reading into memory. + * @param callback handles the reply -- onSuccess will only be called when both message and data + * are received successfully. + */ + public long uploadStream( + ManagedBuffer meta, + ManagedBuffer data, + RpcResponseCallback callback) { + if (logger.isTraceEnabled()) { + logger.trace("Sending RPC to {}", getRemoteAddress(channel)); + } + + long requestId = requestId(); + handler.addRpcRequest(requestId, callback); + + RpcChannelListener listener = new RpcChannelListener(requestId, callback); + channel.writeAndFlush(new UploadStream(requestId, meta, data)).addListener(listener); return requestId; } @@ -319,4 +302,60 @@ public String toString() { .add("isActive", isActive()) .toString(); } + + private static long requestId() { + return Math.abs(UUID.randomUUID().getLeastSignificantBits()); + } + + private class StdChannelListener + implements GenericFutureListener> { + final long startTime; + final Object requestId; + + StdChannelListener(Object requestId) { + this.startTime = System.currentTimeMillis(); + this.requestId = requestId; + } + + @Override + public void operationComplete(Future future) throws Exception { + if (future.isSuccess()) { + if (logger.isTraceEnabled()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.trace("Sending request {} to {} took {} ms", requestId, + getRemoteAddress(channel), timeTaken); + } + } else { + String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, + getRemoteAddress(channel), future.cause()); + logger.error(errorMsg, future.cause()); + channel.close(); + try { + handleFailure(errorMsg, future.cause()); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } + } + } + + void handleFailure(String errorMsg, Throwable cause) throws Exception {} + } + + private class RpcChannelListener extends StdChannelListener { + final long rpcRequestId; + final RpcResponseCallback callback; + + RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) { + super("RPC " + rpcRequestId); + this.rpcRequestId = rpcRequestId; + this.callback = callback; + } + + @Override + void handleFailure(String errorMsg, Throwable cause) { + handler.removeRpcRequest(rpcRequestId); + callback.onFailure(new IOException(errorMsg, cause)); + } + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 8a6e3858081bf..fb44dbbb0953b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -29,6 +29,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.sasl.SaslRpcHandler; @@ -149,6 +150,14 @@ public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + return delegate.receiveStream(client, message, callback); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java index 434935a8ef2ad..0ccd70c03aba8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -37,7 +37,7 @@ enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9), User(-1); + OneWayMessage(9), UploadStream(10), User(-1); private final byte id; @@ -65,6 +65,7 @@ public static Type decode(ByteBuf buf) { case 7: return StreamResponse; case 8: return StreamFailure; case 9: return OneWayMessage; + case 10: return UploadStream; case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 39a7495828a8a..bf80aed0afe10 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -80,6 +80,9 @@ private Message decode(Message.Type msgType, ByteBuf in) { case StreamFailure: return StreamFailure.decode(in); + case UploadStream: + return UploadStream.decode(in); + default: throw new IllegalArgumentException("Unexpected message type: " + msgType); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index a5337656cbd84..e7b66a6f33a82 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -137,30 +137,15 @@ protected void deallocate() { } private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { - ByteBuffer buffer = buf.nioBuffer(); - int written = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? - target.write(buffer) : writeNioBuffer(target, buffer); + // SPARK-24578: cap the sub-region's size of returned nio buffer to improve the performance + // for the case that the passed-in buffer has too many components. + int length = Math.min(buf.readableBytes(), NIO_BUFFER_LIMIT); + ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length); + int written = target.write(buffer); buf.skipBytes(written); return written; } - private int writeNioBuffer( - WritableByteChannel writeCh, - ByteBuffer buf) throws IOException { - int originalLimit = buf.limit(); - int ret = 0; - - try { - int ioSize = Math.min(buf.remaining(), NIO_BUFFER_LIMIT); - buf.limit(buf.position() + ioSize); - ret = writeCh.write(buf); - } finally { - buf.limit(originalLimit); - } - - return ret; - } - @Override public MessageWithHeader touch(Object o) { super.touch(o); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java index 87e212f3e157b..50b811604b84b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -67,7 +67,7 @@ public static StreamResponse decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(byteCount, streamId, body()); + return Objects.hashCode(byteCount, streamId); } @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java new file mode 100644 index 0000000000000..fa1d26e76b852 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * An RPC with data that is sent outside of the frame, so it can be read as a stream. + */ +public final class UploadStream extends AbstractMessage implements RequestMessage { + /** Used to link an RPC request with its response. */ + public final long requestId; + public final ManagedBuffer meta; + public final long bodyByteCount; + + public UploadStream(long requestId, ManagedBuffer meta, ManagedBuffer body) { + super(body, false); // body is *not* included in the frame + this.requestId = requestId; + this.meta = meta; + bodyByteCount = body.size(); + } + + // this version is called when decoding the bytes on the receiving end. The body is handled + // separately. + private UploadStream(long requestId, ManagedBuffer meta, long bodyByteCount) { + super(null, false); + this.requestId = requestId; + this.meta = meta; + this.bodyByteCount = bodyByteCount; + } + + @Override + public Type type() { return Type.UploadStream; } + + @Override + public int encodedLength() { + // the requestId, meta size, meta and bodyByteCount (body is not included) + return 8 + 4 + ((int) meta.size()) + 8; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + try { + ByteBuffer metaBuf = meta.nioByteBuffer(); + buf.writeInt(metaBuf.remaining()); + buf.writeBytes(metaBuf); + } catch (IOException io) { + throw new RuntimeException(io); + } + buf.writeLong(bodyByteCount); + } + + public static UploadStream decode(ByteBuf buf) { + long requestId = buf.readLong(); + int metaSize = buf.readInt(); + ManagedBuffer meta = new NettyManagedBuffer(buf.readRetainedSlice(metaSize)); + long bodyByteCount = buf.readLong(); + // This is called by the frame decoder, so the data is still null. We need a StreamInterceptor + // to read the data. + return new UploadStream(requestId, meta, bodyByteCount); + } + + @Override + public int hashCode() { + return Long.hashCode(requestId); + } + + @Override + public boolean equals(Object other) { + if (other instanceof UploadStream) { + UploadStream o = (UploadStream) other; + return requestId == o.requestId && super.equals(o); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("body", body()) + .toString(); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 0231428318add..355a3def8cc22 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -28,6 +28,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; @@ -132,6 +133,14 @@ public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + return delegate.receiveStream(client, message, callback); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 8f7554e2e07d5..38569baf82bce 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -23,6 +23,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; /** @@ -36,7 +37,8 @@ public abstract class RpcHandler { * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. * - * This method will not be called in parallel for a single TransportClient (i.e., channel). + * Neither this method nor #receiveStream will be called in parallel for a single + * TransportClient (i.e., channel). * * @param client A channel client which enables the handler to make requests back to the sender * of this RPC. This will always be the exact same object for a particular channel. @@ -49,6 +51,36 @@ public abstract void receive( ByteBuffer message, RpcResponseCallback callback); + /** + * Receive a single RPC message which includes data that is to be received as a stream. Any + * exception thrown while in this method will be sent back to the client in string form as a + * standard RPC failure. + * + * Neither this method nor #receive will be called in parallel for a single TransportClient + * (i.e., channel). + * + * An error while reading data from the stream + * ({@link org.apache.spark.network.client.StreamCallback#onData(String, ByteBuffer)}) + * will fail the entire channel. A failure in "post-processing" the stream in + * {@link org.apache.spark.network.client.StreamCallback#onComplete(String)} will result in an + * rpcFailure, but the channel will remain active. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. This will always be the exact same object for a particular channel. + * @param messageHeader The serialized bytes of the header portion of the RPC. This is in meant + * to be relatively small, and will be buffered entirely in memory, to + * facilitate how the streaming portion should be received. + * @param callback Callback which should be invoked exactly once upon success or failure of the + * RPC. + * @return a StreamCallback for handling the accompanying streaming data + */ + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + /** * Returns the StreamManager which contains the state about which streams are currently being * fetched by a TransportClient. diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index e94453578e6b0..e1d7b2dbff60f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.server; +import java.io.IOException; import java.net.SocketAddress; import java.nio.ByteBuffer; @@ -28,20 +29,10 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.ChunkFetchFailure; -import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.OneWayMessage; -import org.apache.spark.network.protocol.RequestMessage; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.RpcResponse; -import org.apache.spark.network.protocol.StreamFailure; -import org.apache.spark.network.protocol.StreamRequest; -import org.apache.spark.network.protocol.StreamResponse; +import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.*; +import org.apache.spark.network.util.TransportFrameDecoder; + import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** @@ -52,6 +43,7 @@ * The messages should have been processed by the pipeline setup by {@link TransportServer}. */ public class TransportRequestHandler extends MessageHandler { + private static final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class); /** The Netty channel that this handler is associated with. */ @@ -113,6 +105,8 @@ public void handle(RequestMessage request) { processOneWayMessage((OneWayMessage) request); } else if (request instanceof StreamRequest) { processStreamRequest((StreamRequest) request); + } else if (request instanceof UploadStream) { + processStreamUpload((UploadStream) request); } else { throw new IllegalArgumentException("Unknown request type: " + request); } @@ -203,6 +197,79 @@ public void onFailure(Throwable e) { } } + /** + * Handle a request from the client to upload a stream of data. + */ + private void processStreamUpload(final UploadStream req) { + assert (req.body() == null); + try { + RpcResponseCallback callback = new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); + } + + @Override + public void onFailure(Throwable e) { + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } + }; + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + ByteBuffer meta = req.meta.nioByteBuffer(); + StreamCallbackWithID streamHandler = rpcHandler.receiveStream(reverseClient, meta, callback); + if (streamHandler == null) { + throw new NullPointerException("rpcHandler returned a null streamHandler"); + } + StreamCallbackWithID wrappedCallback = new StreamCallbackWithID() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + streamHandler.onData(streamId, buf); + } + + @Override + public void onComplete(String streamId) throws IOException { + try { + streamHandler.onComplete(streamId); + callback.onSuccess(ByteBuffer.allocate(0)); + } catch (Exception ex) { + IOException ioExc = new IOException("Failure post-processing complete stream;" + + " failing this rpc and leaving channel active"); + callback.onFailure(ioExc); + streamHandler.onFailure(streamId, ioExc); + } + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + callback.onFailure(new IOException("Destination failed while reading stream", cause)); + streamHandler.onFailure(streamId, cause); + } + + @Override + public String getID() { + return streamHandler.getID(); + } + }; + if (req.bodyByteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor(this, wrappedCallback.getID(), + req.bodyByteCount, wrappedCallback); + frameDecoder.setInterceptor(interceptor); + } else { + wrappedCallback.onComplete(wrappedCallback.getID()); + } + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + // We choose to totally fail the channel, rather than trying to recover as we do in other + // cases. We don't know how many bytes of the stream the client has already sent for the + // stream, it's not worth trying to recover. + channel.pipeline().fireExceptionCaught(e); + } finally { + req.meta.release(); + } + } + private void processOneWayMessage(OneWayMessage req) { try { rpcHandler.receive(reverseClient, req.body().nioByteBuffer()); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 0719fa7647bcc..60f51125c07fd 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -32,6 +32,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; +import org.apache.commons.lang3.SystemUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -98,6 +99,7 @@ private void init(String hostToBind, int portToBind) { .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) .option(ChannelOption.ALLOCATOR, allocator) + .option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS) .childOption(ChannelOption.ALLOCATOR, allocator); this.metrics = new NettyMemoryMetrics( diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java index afc59efaef810..b5497087634ce 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -17,10 +17,7 @@ package org.apache.spark.network.util; -import java.io.Closeable; -import java.io.EOFException; -import java.io.File; -import java.io.IOException; +import java.io.*; import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; import java.nio.charset.StandardCharsets; @@ -91,11 +88,24 @@ public static String bytesToString(ByteBuffer b) { * @throws IOException if deletion is unsuccessful */ public static void deleteRecursively(File file) throws IOException { + deleteRecursively(file, null); + } + + /** + * Delete a file or directory and its contents recursively. + * Don't follow directories if they are symlinks. + * + * @param file Input file / dir to be deleted + * @param filter A filename filter that make sure only files / dirs with the satisfied filenames + * are deleted. + * @throws IOException if deletion is unsuccessful + */ + public static void deleteRecursively(File file, FilenameFilter filter) throws IOException { if (file == null) { return; } // On Unix systems, use operating system command to run faster // If that does not work out, fallback to the Java IO way - if (SystemUtils.IS_OS_UNIX) { + if (SystemUtils.IS_OS_UNIX && filter == null) { try { deleteRecursivelyUsingUnixNative(file); return; @@ -105,15 +115,17 @@ public static void deleteRecursively(File file) throws IOException { } } - deleteRecursivelyUsingJavaIO(file); + deleteRecursivelyUsingJavaIO(file, filter); } - private static void deleteRecursivelyUsingJavaIO(File file) throws IOException { + private static void deleteRecursivelyUsingJavaIO( + File file, + FilenameFilter filter) throws IOException { if (file.isDirectory() && !isSymlink(file)) { IOException savedIOException = null; - for (File child : listFilesSafely(file)) { + for (File child : listFilesSafely(file, filter)) { try { - deleteRecursively(child); + deleteRecursively(child, filter); } catch (IOException e) { // In case of multiple exceptions, only last one will be thrown savedIOException = e; @@ -124,10 +136,13 @@ private static void deleteRecursivelyUsingJavaIO(File file) throws IOException { } } - boolean deleted = file.delete(); - // Delete can also fail if the file simply did not exist. - if (!deleted && file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath()); + // Delete file only when it's a normal file or an empty directory. + if (file.isFile() || (file.isDirectory() && listFilesSafely(file, null).length == 0)) { + boolean deleted = file.delete(); + // Delete can also fail if the file simply did not exist. + if (!deleted && file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath()); + } } } @@ -157,9 +172,9 @@ private static void deleteRecursivelyUsingUnixNative(File file) throws IOExcepti } } - private static File[] listFilesSafely(File file) throws IOException { + private static File[] listFilesSafely(File file, FilenameFilter filter) throws IOException { if (file.exists()) { - File[] files = file.listFiles(); + File[] files = file.listFiles(filter); if (files == null) { throw new IOException("Failed to list files for dir: " + file); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 8ff737b129641..1f4d75c7e2ec5 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,43 +17,46 @@ package org.apache.spark.network; +import java.io.*; import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Set; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import com.google.common.collect.Sets; +import com.google.common.io.Files; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.*; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.OneForOneStreamManager; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.*; +import org.apache.spark.network.server.*; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class RpcIntegrationSuite { + static TransportConf conf; static TransportServer server; static TransportClientFactory clientFactory; static RpcHandler rpcHandler; static List oneWayMsgs; + static StreamTestHelper testData; + + static ConcurrentHashMap streamCallbacks = + new ConcurrentHashMap<>(); @BeforeClass public static void setUp() throws Exception { - TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + testData = new StreamTestHelper(); rpcHandler = new RpcHandler() { @Override public void receive( @@ -71,6 +74,14 @@ public void receive( } } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + return receiveStreamHelper(JavaUtils.bytesToString(messageHeader)); + } + @Override public void receive(TransportClient client, ByteBuffer message) { oneWayMsgs.add(JavaUtils.bytesToString(message)); @@ -85,10 +96,71 @@ public void receive(TransportClient client, ByteBuffer message) { oneWayMsgs = new ArrayList<>(); } + private static StreamCallbackWithID receiveStreamHelper(String msg) { + try { + if (msg.startsWith("fail/")) { + String[] parts = msg.split("/"); + switch (parts[1]) { + case "exception-ondata": + return new StreamCallbackWithID() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + throw new IOException("failed to read stream data!"); + } + + @Override + public void onComplete(String streamId) throws IOException { + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + } + + @Override + public String getID() { + return msg; + } + }; + case "exception-oncomplete": + return new StreamCallbackWithID() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + } + + @Override + public void onComplete(String streamId) throws IOException { + throw new IOException("exception in onComplete"); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + } + + @Override + public String getID() { + return msg; + } + }; + case "null": + return null; + default: + throw new IllegalArgumentException("unexpected msg: " + msg); + } + } else { + VerifyingStreamCallback streamCallback = new VerifyingStreamCallback(msg); + streamCallbacks.put(msg, streamCallback); + return streamCallback; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + @AfterClass public static void tearDown() { server.close(); clientFactory.close(); + testData.cleanup(); } static class RpcResult { @@ -130,6 +202,59 @@ public void onFailure(Throwable e) { return res; } + private RpcResult sendRpcWithStream(String... streams) throws Exception { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + RpcResult res = new RpcResult(); + res.successMessages = Collections.synchronizedSet(new HashSet()); + res.errorMessages = Collections.synchronizedSet(new HashSet()); + + for (String stream : streams) { + int idx = stream.lastIndexOf('/'); + ManagedBuffer meta = new NioManagedBuffer(JavaUtils.stringToBytes(stream)); + String streamName = (idx == -1) ? stream : stream.substring(idx + 1); + ManagedBuffer data = testData.openStream(conf, streamName); + client.uploadStream(meta, data, new RpcStreamCallback(stream, res, sem)); + } + + if (!sem.tryAcquire(streams.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + streamCallbacks.values().forEach(streamCallback -> { + try { + streamCallback.verify(); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + client.close(); + return res; + } + + private static class RpcStreamCallback implements RpcResponseCallback { + final String streamId; + final RpcResult res; + final Semaphore sem; + + RpcStreamCallback(String streamId, RpcResult res, Semaphore sem) { + this.streamId = streamId; + this.res = res; + this.sem = sem; + } + + @Override + public void onSuccess(ByteBuffer message) { + res.successMessages.add(streamId); + sem.release(); + } + + @Override + public void onFailure(Throwable e) { + res.errorMessages.add(e.getMessage()); + sem.release(); + } + } + @Test public void singleRPC() throws Exception { RpcResult res = sendRPC("hello/Aaron"); @@ -193,10 +318,83 @@ public void sendOneWayMessage() throws Exception { } } + @Test + public void sendRpcWithStreamOneAtATime() throws Exception { + for (String stream : StreamTestHelper.STREAMS) { + RpcResult res = sendRpcWithStream(stream); + assertTrue("there were error messages!" + res.errorMessages, res.errorMessages.isEmpty()); + assertEquals(Sets.newHashSet(stream), res.successMessages); + } + } + + @Test + public void sendRpcWithStreamConcurrently() throws Exception { + String[] streams = new String[10]; + for (int i = 0; i < 10; i++) { + streams[i] = StreamTestHelper.STREAMS[i % StreamTestHelper.STREAMS.length]; + } + RpcResult res = sendRpcWithStream(streams); + assertEquals(Sets.newHashSet(StreamTestHelper.STREAMS), res.successMessages); + assertTrue(res.errorMessages.isEmpty()); + } + + @Test + public void sendRpcWithStreamFailures() throws Exception { + // when there is a failure reading stream data, we don't try to keep the channel usable, + // just send back a decent error msg. + RpcResult exceptionInCallbackResult = + sendRpcWithStream("fail/exception-ondata/smallBuffer", "smallBuffer"); + assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream"); + + RpcResult nullStreamHandler = + sendRpcWithStream("fail/null/smallBuffer", "smallBuffer"); + assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream"); + + // OTOH, if there is a failure during onComplete, the channel should still be fine + RpcResult exceptionInOnComplete = + sendRpcWithStream("fail/exception-oncomplete/smallBuffer", "smallBuffer"); + assertErrorsContain(exceptionInOnComplete.errorMessages, + Sets.newHashSet("Failure post-processing")); + assertEquals(Sets.newHashSet("smallBuffer"), exceptionInOnComplete.successMessages); + } + private void assertErrorsContain(Set errors, Set contains) { - assertEquals(contains.size(), errors.size()); + assertEquals("Expected " + contains.size() + " errors, got " + errors.size() + "errors: " + + errors, contains.size(), errors.size()); + + Pair, Set> r = checkErrorsContain(errors, contains); + assertTrue("Could not find error containing " + r.getRight() + "; errors: " + errors, + r.getRight().isEmpty()); + + assertTrue(r.getLeft().isEmpty()); + } + + private void assertErrorAndClosed(RpcResult result, String expectedError) { + assertTrue("unexpected success: " + result.successMessages, result.successMessages.isEmpty()); + // we expect 1 additional error, which contains *either* "closed" or "Connection reset" + Set errors = result.errorMessages; + assertEquals("Expected 2 errors, got " + errors.size() + "errors: " + + errors, 2, errors.size()); + + Set containsAndClosed = Sets.newHashSet(expectedError); + containsAndClosed.add("closed"); + containsAndClosed.add("Connection reset"); + + Pair, Set> r = checkErrorsContain(errors, containsAndClosed); + Set errorsNotFound = r.getRight(); + assertEquals(1, errorsNotFound.size()); + String err = errorsNotFound.iterator().next(); + assertTrue(err.equals("closed") || err.equals("Connection reset")); + + assertTrue(r.getLeft().isEmpty()); + } + + private Pair, Set> checkErrorsContain( + Set errors, + Set contains) { Set remainingErrors = Sets.newHashSet(errors); + Set notFound = Sets.newHashSet(); for (String contain : contains) { Iterator it = remainingErrors.iterator(); boolean foundMatch = false; @@ -207,9 +405,66 @@ private void assertErrorsContain(Set errors, Set contains) { break; } } - assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch); + if (!foundMatch) { + notFound.add(contain); + } + } + return new ImmutablePair<>(remainingErrors, notFound); + } + + private static class VerifyingStreamCallback implements StreamCallbackWithID { + final String streamId; + final StreamSuite.TestCallback helper; + final OutputStream out; + final File outFile; + + VerifyingStreamCallback(String streamId) throws IOException { + if (streamId.equals("file")) { + outFile = File.createTempFile("data", ".tmp", testData.tempDir); + out = new FileOutputStream(outFile); + } else { + out = new ByteArrayOutputStream(); + outFile = null; + } + this.streamId = streamId; + helper = new StreamSuite.TestCallback(out); + } + + void verify() throws IOException { + if (streamId.equals("file")) { + assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile)); + } else { + byte[] result = ((ByteArrayOutputStream)out).toByteArray(); + ByteBuffer srcBuffer = testData.srcBuffer(streamId); + ByteBuffer base; + synchronized (srcBuffer) { + base = srcBuffer.duplicate(); + } + byte[] expected = new byte[base.remaining()]; + base.get(expected); + assertEquals(expected.length, result.length); + assertTrue("buffers don't match", Arrays.equals(expected, result)); + } + } + + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + helper.onData(streamId, buf); } - assertTrue(remainingErrors.isEmpty()); + @Override + public void onComplete(String streamId) throws IOException { + helper.onComplete(streamId); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + helper.onFailure(streamId, cause); + } + + @Override + public String getID() { + return streamId; + } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index f253a07e64be1..f3050cb79cdfd 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -26,7 +26,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Random; import java.util.concurrent.Executors; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; @@ -37,9 +36,7 @@ import org.junit.Test; import static org.junit.Assert.*; -import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; @@ -51,16 +48,11 @@ import org.apache.spark.network.util.TransportConf; public class StreamSuite { - private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; + private static final String[] STREAMS = StreamTestHelper.STREAMS; + private static StreamTestHelper testData; private static TransportServer server; private static TransportClientFactory clientFactory; - private static File testFile; - private static File tempDir; - - private static ByteBuffer emptyBuffer; - private static ByteBuffer smallBuffer; - private static ByteBuffer largeBuffer; private static ByteBuffer createBuffer(int bufSize) { ByteBuffer buf = ByteBuffer.allocate(bufSize); @@ -73,23 +65,7 @@ private static ByteBuffer createBuffer(int bufSize) { @BeforeClass public static void setUp() throws Exception { - tempDir = Files.createTempDir(); - emptyBuffer = createBuffer(0); - smallBuffer = createBuffer(100); - largeBuffer = createBuffer(100000); - - testFile = File.createTempFile("stream-test-file", "txt", tempDir); - FileOutputStream fp = new FileOutputStream(testFile); - try { - Random rnd = new Random(); - for (int i = 0; i < 512; i++) { - byte[] fileContent = new byte[1024]; - rnd.nextBytes(fileContent); - fp.write(fileContent); - } - } finally { - fp.close(); - } + testData = new StreamTestHelper(); final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); final StreamManager streamManager = new StreamManager() { @@ -100,18 +76,7 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { @Override public ManagedBuffer openStream(String streamId) { - switch (streamId) { - case "largeBuffer": - return new NioManagedBuffer(largeBuffer); - case "smallBuffer": - return new NioManagedBuffer(smallBuffer); - case "emptyBuffer": - return new NioManagedBuffer(emptyBuffer); - case "file": - return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); - default: - throw new IllegalArgumentException("Invalid stream: " + streamId); - } + return testData.openStream(conf, streamId); } }; RpcHandler handler = new RpcHandler() { @@ -137,12 +102,7 @@ public StreamManager getStreamManager() { public static void tearDown() { server.close(); clientFactory.close(); - if (tempDir != null) { - for (File f : tempDir.listFiles()) { - f.delete(); - } - tempDir.delete(); - } + testData.cleanup(); } @Test @@ -234,21 +194,21 @@ public void run() { case "largeBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = largeBuffer; + srcBuffer = testData.largeBuffer; break; case "smallBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = smallBuffer; + srcBuffer = testData.smallBuffer; break; case "file": - outFile = File.createTempFile("data", ".tmp", tempDir); + outFile = File.createTempFile("data", ".tmp", testData.tempDir); out = new FileOutputStream(outFile); break; case "emptyBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = emptyBuffer; + srcBuffer = testData.emptyBuffer; break; default: throw new IllegalArgumentException(streamId); @@ -256,10 +216,10 @@ public void run() { TestCallback callback = new TestCallback(out); client.stream(streamId, callback); - waitForCompletion(callback); + callback.waitForCompletion(timeoutMs); if (srcBuffer == null) { - assertTrue("File stream did not match.", Files.equal(testFile, outFile)); + assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile)); } else { ByteBuffer base; synchronized (srcBuffer) { @@ -292,23 +252,9 @@ public void check() throws Throwable { throw error; } } - - private void waitForCompletion(TestCallback callback) throws Exception { - long now = System.currentTimeMillis(); - long deadline = now + timeoutMs; - synchronized (callback) { - while (!callback.completed && now < deadline) { - callback.wait(deadline - now); - now = System.currentTimeMillis(); - } - } - assertTrue("Timed out waiting for stream.", callback.completed); - assertNull(callback.error); - } - } - private static class TestCallback implements StreamCallback { + static class TestCallback implements StreamCallback { private final OutputStream out; public volatile boolean completed; @@ -344,6 +290,22 @@ public void onFailure(String streamId, Throwable cause) { } } + void waitForCompletion(long timeoutMs) { + long now = System.currentTimeMillis(); + long deadline = now + timeoutMs; + synchronized (this) { + while (!completed && now < deadline) { + try { + wait(deadline - now); + } catch (InterruptedException ie) { + throw new RuntimeException(ie); + } + now = System.currentTimeMillis(); + } + } + assertTrue("Timed out waiting for stream.", completed); + assertNull(error); + } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java new file mode 100644 index 0000000000000..0f5c82c9e9b1f --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Random; + +import com.google.common.io.Files; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; + +class StreamTestHelper { + static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; + + final File testFile; + final File tempDir; + + final ByteBuffer emptyBuffer; + final ByteBuffer smallBuffer; + final ByteBuffer largeBuffer; + + private static ByteBuffer createBuffer(int bufSize) { + ByteBuffer buf = ByteBuffer.allocate(bufSize); + for (int i = 0; i < bufSize; i ++) { + buf.put((byte) i); + } + buf.flip(); + return buf; + } + + StreamTestHelper() throws Exception { + tempDir = Files.createTempDir(); + emptyBuffer = createBuffer(0); + smallBuffer = createBuffer(100); + largeBuffer = createBuffer(100000); + + testFile = File.createTempFile("stream-test-file", "txt", tempDir); + FileOutputStream fp = new FileOutputStream(testFile); + try { + Random rnd = new Random(); + for (int i = 0; i < 512; i++) { + byte[] fileContent = new byte[1024]; + rnd.nextBytes(fileContent); + fp.write(fileContent); + } + } finally { + fp.close(); + } + } + + public ByteBuffer srcBuffer(String name) { + switch (name) { + case "largeBuffer": + return largeBuffer; + case "smallBuffer": + return smallBuffer; + case "emptyBuffer": + return emptyBuffer; + default: + throw new IllegalArgumentException("Invalid stream: " + name); + } + } + + public ManagedBuffer openStream(TransportConf conf, String streamId) { + switch (streamId) { + case "file": + return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); + default: + return new NioManagedBuffer(srcBuffer(streamId)); + } + } + + void cleanup() { + if (tempDir != null) { + try { + JavaUtils.deleteRecursively(tempDir); + } catch (IOException io) { + throw new RuntimeException(io); + } + } + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index fc7bba41185f0..098fa7974b87b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -138,6 +138,13 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { blockManager.applicationRemoved(appId, cleanupLocalDirs); } + /** + * Clean up any non-shuffle files in any local directories associated with an finished executor. + */ + public void executorRemoved(String executorId, String appId) { + blockManager.executorRemoved(executorId, appId); + } + /** * Register an (application, executor) with the given shuffle info. * diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index e6399897be9c2..0b7a27402369d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -24,6 +24,8 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -59,6 +61,7 @@ public class ExternalShuffleBlockResolver { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); private static final ObjectMapper mapper = new ObjectMapper(); + /** * This a common prefix to the key for each app registration we stick in leveldb, so they * are easy to find, since leveldb lets you search based on prefix. @@ -66,6 +69,8 @@ public class ExternalShuffleBlockResolver { private static final String APP_KEY_PREFIX = "AppExecShuffleInfo"; private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0); + private static final Pattern MULTIPLE_SEPARATORS = Pattern.compile(File.separator + "{2,}"); + // Map containing all registered executors' metadata. @VisibleForTesting final ConcurrentMap executors; @@ -211,6 +216,26 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { } } + /** + * Removes all the non-shuffle files in any local directories associated with the finished + * executor. + */ + public void executorRemoved(String executorId, String appId) { + logger.info("Clean up non-shuffle files associated with the finished executor {}", executorId); + AppExecId fullId = new AppExecId(appId, executorId); + final ExecutorShuffleInfo executor = executors.get(fullId); + if (executor == null) { + // Executor not registered, skip clean up of the local directories. + logger.info("Executor is not registered (appId={}, execId={})", appId, executorId); + } else { + logger.info("Cleaning up non-shuffle files in executor {}'s {} local dirs", fullId, + executor.localDirs.length); + + // Execute the actual deletion in a different thread, as it may take some time. + directoryCleaner.execute(() -> deleteNonShuffleFiles(executor.localDirs)); + } + } + /** * Synchronously deletes each directory one at a time. * Should be executed in its own thread, as this may take a long time. @@ -226,6 +251,29 @@ private void deleteExecutorDirs(String[] dirs) { } } + /** + * Synchronously deletes non-shuffle files in each directory recursively. + * Should be executed in its own thread, as this may take a long time. + */ + private void deleteNonShuffleFiles(String[] dirs) { + FilenameFilter filter = new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + // Don't delete shuffle data or shuffle index files. + return !name.endsWith(".index") && !name.endsWith(".data"); + } + }; + + for (String localDir : dirs) { + try { + JavaUtils.deleteRecursively(new File(localDir), filter); + logger.debug("Successfully cleaned up non-shuffle files in directory: {}", localDir); + } catch (Exception e) { + logger.error("Failed to delete non-shuffle files in directory: " + localDir, e); + } + } + } + /** * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockResolver, @@ -259,7 +307,8 @@ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) int hash = JavaUtils.nonNegativeHash(filename); String localDir = localDirs[hash % localDirs.length]; int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; - return new File(new File(localDir, String.format("%02x", subDirId)), filename); + return new File(createNormalizedInternedPathname( + localDir, String.format("%02x", subDirId), filename)); } void close() { @@ -272,6 +321,28 @@ void close() { } } + /** + * This method is needed to avoid the situation when multiple File instances for the + * same pathname "foo/bar" are created, each with a separate copy of the "foo/bar" String. + * According to measurements, in some scenarios such duplicate strings may waste a lot + * of memory (~ 10% of the heap). To avoid that, we intern the pathname, and before that + * we make sure that it's in a normalized form (contains no "//", "///" etc.) Otherwise, + * the internal code in java.io.File would normalize it later, creating a new "foo/bar" + * String copy. Unfortunately, we cannot just reuse the normalization code that java.io.File + * uses, since it is in the package-private class java.io.FileSystem. + */ + @VisibleForTesting + static String createNormalizedInternedPathname(String dir1, String dir2, String fname) { + String pathname = dir1 + File.separator + dir2 + File.separator + fname; + Matcher m = MULTIPLE_SEPARATORS.matcher(pathname); + pathname = m.replaceAll("/"); + // A single trailing slash needs to be taken care of separately + if (pathname.length() > 1 && pathname.endsWith("/")) { + pathname = pathname.substring(0, pathname.length() - 1); + } + return pathname.intern(); + } + /** Simply encodes an executor's full ID, which is appId + execId. */ public static class AppExecId { public final String appId; diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 6d201b8fe8d7d..d2072a54fa415 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.io.File; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; @@ -135,4 +136,23 @@ public void jsonSerializationOfExecutorRegistration() throws IOException { "\"subDirsPerLocalDir\": 7, \"shuffleManager\": " + "\"" + SORT_MANAGER + "\"}"; assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class)); } + + @Test + public void testNormalizeAndInternPathname() { + assertPathsMatch("/foo", "bar", "baz", "/foo/bar/baz"); + assertPathsMatch("//foo/", "bar/", "//baz", "/foo/bar/baz"); + assertPathsMatch("foo", "bar", "baz///", "foo/bar/baz"); + assertPathsMatch("/foo/", "/bar//", "/baz", "/foo/bar/baz"); + assertPathsMatch("/", "", "", "/"); + assertPathsMatch("/", "/", "/", "/"); + } + + private void assertPathsMatch(String p1, String p2, String p3, String expectedPathname) { + String normPathname = + ExternalShuffleBlockResolver.createNormalizedInternedPathname(p1, p2, p3); + assertEquals(expectedPathname, normPathname); + File file = new File(normPathname); + String returnedPath = file.getPath(); + assertTrue(normPathname == returnedPath); + } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java new file mode 100644 index 0000000000000..d22f3ace4103b --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.FilenameFilter; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Random; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.google.common.util.concurrent.MoreExecutors; +import org.junit.Test; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class NonShuffleFilesCleanupSuite { + + // Same-thread Executor used to ensure cleanup happens synchronously in test thread. + private Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + private TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + private static final String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; + + @Test + public void cleanupOnRemovedExecutorWithShuffleFiles() throws IOException { + cleanupOnRemovedExecutor(true); + } + + @Test + public void cleanupOnRemovedExecutorWithoutShuffleFiles() throws IOException { + cleanupOnRemovedExecutor(false); + } + + private void cleanupOnRemovedExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + resolver.executorRemoved("exec0", "app"); + + assertCleanedUp(dataContext); + } + + @Test + public void cleanupUsesExecutorWithShuffleFiles() throws IOException { + cleanupUsesExecutor(true); + } + + @Test + public void cleanupUsesExecutorWithoutShuffleFiles() throws IOException { + cleanupUsesExecutor(false); + } + + private void cleanupUsesExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); + + AtomicBoolean cleanupCalled = new AtomicBoolean(false); + + // Executor which does nothing to ensure we're actually using it. + Executor noThreadExecutor = runnable -> cleanupCalled.set(true); + + ExternalShuffleBlockResolver manager = + new ExternalShuffleBlockResolver(conf, null, noThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + manager.executorRemoved("exec0", "app"); + + assertTrue(cleanupCalled.get()); + assertStillThere(dataContext); + } + + @Test + public void cleanupOnlyRemovedExecutorWithShuffleFiles() throws IOException { + cleanupOnlyRemovedExecutor(true); + } + + @Test + public void cleanupOnlyRemovedExecutorWithoutShuffleFiles() throws IOException { + cleanupOnlyRemovedExecutor(false); + } + + private void cleanupOnlyRemovedExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext0 = initDataContext(withShuffleFiles); + TestShuffleDataContext dataContext1 = initDataContext(withShuffleFiles); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo(SORT_MANAGER)); + resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo(SORT_MANAGER)); + + + resolver.executorRemoved("exec-nonexistent", "app"); + assertStillThere(dataContext0); + assertStillThere(dataContext1); + + resolver.executorRemoved("exec0", "app"); + assertCleanedUp(dataContext0); + assertStillThere(dataContext1); + + resolver.executorRemoved("exec1", "app"); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + + // Make sure it's not an error to cleanup multiple times + resolver.executorRemoved("exec1", "app"); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + @Test + public void cleanupOnlyRegisteredExecutorWithShuffleFiles() throws IOException { + cleanupOnlyRegisteredExecutor(true); + } + + @Test + public void cleanupOnlyRegisteredExecutorWithoutShuffleFiles() throws IOException { + cleanupOnlyRegisteredExecutor(false); + } + + private void cleanupOnlyRegisteredExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + + resolver.executorRemoved("exec1", "app"); + assertStillThere(dataContext); + + resolver.executorRemoved("exec0", "app"); + assertCleanedUp(dataContext); + } + + private static void assertStillThere(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); + } + } + + private static FilenameFilter filter = new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + // Don't delete shuffle data or shuffle index files. + return !name.endsWith(".index") && !name.endsWith(".data"); + } + }; + + private static boolean assertOnlyShuffleDataInDir(File[] dirs) { + for (File dir : dirs) { + assertTrue(dir.getName() + " wasn't cleaned up", !dir.exists() || + dir.listFiles(filter).length == 0 || assertOnlyShuffleDataInDir(dir.listFiles())); + } + return true; + } + + private static void assertCleanedUp(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + File[] dirs = new File[] {new File(localDir)}; + assertOnlyShuffleDataInDir(dirs); + } + } + + private static TestShuffleDataContext initDataContext(boolean withShuffleFiles) + throws IOException { + if (withShuffleFiles) { + return initDataContextWithShuffleFiles(); + } else { + return initDataContextWithoutShuffleFiles(); + } + } + + private static TestShuffleDataContext initDataContextWithShuffleFiles() throws IOException { + TestShuffleDataContext dataContext = createDataContext(); + createShuffleFiles(dataContext); + createNonShuffleFiles(dataContext); + return dataContext; + } + + private static TestShuffleDataContext initDataContextWithoutShuffleFiles() throws IOException { + TestShuffleDataContext dataContext = createDataContext(); + createNonShuffleFiles(dataContext); + return dataContext; + } + + private static TestShuffleDataContext createDataContext() { + TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); + dataContext.create(); + return dataContext; + } + + private static void createShuffleFiles(TestShuffleDataContext dataContext) throws IOException { + Random rand = new Random(123); + dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), new byte[][] { + "ABC".getBytes(StandardCharsets.UTF_8), + "DEF".getBytes(StandardCharsets.UTF_8)}); + } + + private static void createNonShuffleFiles(TestShuffleDataContext dataContext) throws IOException { + // Create spill file(s) + dataContext.insertSpillData(); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 81e01949e50fa..6989c3baf2e28 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -22,6 +22,7 @@ import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; +import java.util.UUID; import com.google.common.io.Closeables; import com.google.common.io.Files; @@ -94,6 +95,20 @@ public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) thr } } + /** Creates spill file(s) within the local dirs. */ + public void insertSpillData() throws IOException { + String filename = "temp_local_" + UUID.randomUUID(); + OutputStream dataStream = null; + + try { + dataStream = new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, filename)); + dataStream.write(42); + } finally { + Closeables.close(dataStream, false); + } + } + /** * Creates an ExecutorShuffleInfo object based on the given shuffle manager which targets this * context's directories. diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index c03caf0076f61..ecd7c19f2c634 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -17,10 +17,12 @@ package org.apache.spark.unsafe.types; -import org.apache.spark.unsafe.Platform; - import java.util.Arrays; +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + public final class ByteArray { public static final byte[] EMPTY_BYTE = new byte[0]; @@ -77,17 +79,17 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) { public static byte[] concat(byte[]... inputs) { // Compute the total length of the result - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { - totalLength += inputs[i].length; + totalLength += (long)inputs[i].length; } else { return null; } } // Allocate a new byte array, and copy the inputs one by one into it - final byte[] result = new byte[totalLength]; + final byte[] result = new byte[Ints.checkedCast(totalLength)]; int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].length; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e9b3d9b045af5..e91fc4391425c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -29,8 +29,8 @@ import com.esotericsoftware.kryo.KryoSerializable; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; - import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; @@ -877,17 +877,17 @@ public UTF8String lpad(int len, UTF8String pad) { */ public static UTF8String concat(UTF8String... inputs) { // Compute the total length of the result. - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { - totalLength += inputs[i].numBytes; + totalLength += (long)inputs[i].numBytes; } else { return null; } } // Allocate a new byte array, and copy the inputs one by one into it. - final byte[] result = new byte[totalLength]; + final byte[] result = new byte[Ints.checkedCast(totalLength)]; int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java index 5d5fdc1c55a75..ef5ff8ee70ec0 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -120,6 +120,8 @@ private void check(MemoryBlock memory, Object obj, long offset, int length) { } catch (Exception expected) { Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); } + + memory.setPageNumber(MemoryBlock.NO_PAGE_NUMBER); } @Test @@ -165,11 +167,13 @@ public void testOffHeapArrayMemoryBlock() { int length = 56; check(memory, obj, offset, length); + memoryAllocator.free(memory); long address = Platform.allocateMemory(112); memory = new OffHeapMemoryBlock(address, length); obj = memory.getBaseObject(); offset = memory.getBaseOffset(); check(memory, obj, offset, length); + Platform.freeMemory(address); } } diff --git a/core/pom.xml b/core/pom.xml index 9258a856028a0..d0b869e6ef92c 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -56,7 +56,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded org.apache.hadoop @@ -95,6 +95,12 @@ org.apache.curator curator-recipes + + + org.apache.zookeeper + zookeeper + @@ -344,7 +350,7 @@ net.sf.py4j py4j - 0.10.6 + 0.10.7 org.apache.spark diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 5f0045507aaab..9a767dd739b91 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -703,7 +703,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff // must be stored in the same memory page. // (8 byte key length) (key) (value) (8 byte pointer to next value) int uaoSize = UnsafeAlignedOffset.getUaoSize(); - final long recordLength = (2 * uaoSize) + klen + vlen + 8; + final long recordLength = (2L * uaoSize) + klen + vlen + 8; if (currentPage == null || currentPage.size() - pageCursor < recordLength) { if (!acquireNewPage(recordLength + uaoSize)) { return false; diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 189d91333c045..aa363eeffffb8 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -26,7 +26,7 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{DYN_ALLOCATION_MAX_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS} +import org.apache.spark.internal.config._ import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMaster @@ -69,6 +69,10 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} * spark.dynamicAllocation.maxExecutors - Upper bound on the number of executors * spark.dynamicAllocation.initialExecutors - Number of executors to start with * + * spark.dynamicAllocation.executorAllocationRatio - + * This is used to reduce the parallelism of the dynamic allocation that can waste + * resources when tasks are small + * * spark.dynamicAllocation.schedulerBacklogTimeout (M) - * If there are backlogged tasks for this duration, add new executors * @@ -116,9 +120,12 @@ private[spark] class ExecutorAllocationManager( // TODO: The default value of 1 for spark.executor.cores works right now because dynamic // allocation is only supported for YARN and the default number of cores per executor in YARN is // 1, but it might need to be attained differently for different cluster managers - private val tasksPerExecutor = + private val tasksPerExecutorForFullParallelism = conf.getInt("spark.executor.cores", 1) / conf.getInt("spark.task.cpus", 1) + private val executorAllocationRatio = + conf.get(DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO) + validateSettings() // Number of executors to add in the next round @@ -209,8 +216,13 @@ private[spark] class ExecutorAllocationManager( throw new SparkException("Dynamic allocation of executors requires the external " + "shuffle service. You may enable this through spark.shuffle.service.enabled.") } - if (tasksPerExecutor == 0) { - throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.") + if (tasksPerExecutorForFullParallelism == 0) { + throw new SparkException("spark.executor.cores must not be < spark.task.cpus.") + } + + if (executorAllocationRatio > 1.0 || executorAllocationRatio <= 0.0) { + throw new SparkException( + "spark.dynamicAllocation.executorAllocationRatio must be > 0 and <= 1.0") } } @@ -273,7 +285,9 @@ private[spark] class ExecutorAllocationManager( */ private def maxNumExecutorsNeeded(): Int = { val numRunningOrPendingTasks = listener.totalPendingTasks + listener.totalRunningTasks - (numRunningOrPendingTasks + tasksPerExecutor - 1) / tasksPerExecutor + math.ceil(numRunningOrPendingTasks * executorAllocationRatio / + tasksPerExecutorForFullParallelism) + .toInt } private def totalRunningTasks(): Int = synchronized { diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index ff960b396dbf1..bcbc8df0d5865 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -74,10 +74,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses // "milliseconds" - private val slaveTimeoutMs = - sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s") private val executorTimeoutMs = - sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000 + sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", + s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s")}s") // "spark.network.timeoutInterval" uses "seconds", while // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 195fd4f818b36..73646051f264c 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -22,7 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.collection.mutable.{HashMap, HashSet, ListBuffer, Map} import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.reflect.ClassTag @@ -282,7 +282,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) } @@ -296,7 +296,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -632,9 +632,10 @@ private[spark] class MapOutputTrackerMaster( } } + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => @@ -642,7 +643,7 @@ private[spark] class MapOutputTrackerMaster( MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } case None => - Seq.empty + Iterator.empty } } @@ -669,8 +670,9 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr /** Remembers which map output locations are currently being fetched on an executor. */ private val fetching = new HashSet[Int] + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { @@ -841,6 +843,7 @@ private[spark] object MapOutputTracker extends Logging { * Given an array of map statuses and a range of map output partitions, returns a sequence that, * for each block manager ID, lists the shuffle block IDs and corresponding shuffle block sizes * stored at that block manager. + * Note that empty blocks are filtered in the result. * * If any of the statuses is null (indicating a missing location due to a failed mapper), * throws a FetchFailedException. @@ -857,22 +860,24 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] - for ((status, mapId) <- statuses.zipWithIndex) { + val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] + for ((status, mapId) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" logError(errorMessage) throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage) } else { for (part <- startPartition until endPartition) { - splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += - ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part))) + val size = status.getSizeForBlock(part) + if (size != 0) { + splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), size)) + } } } } - - splitsByAddress.toSeq + splitsByAddress.iterator } } diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 477b01968c6ef..1632e0c69eef5 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -21,6 +21,7 @@ import java.io.File import java.security.NoSuchAlgorithmException import javax.net.ssl.SSLContext +import org.apache.hadoop.conf.Configuration import org.eclipse.jetty.util.ssl.SslContextFactory import org.apache.spark.internal.Logging @@ -128,7 +129,7 @@ private[spark] case class SSLOptions( } /** Returns a string representation of this SSLOptions with all the passwords masked. */ - override def toString: String = s"SSLOptions{enabled=$enabled, " + + override def toString: String = s"SSLOptions{enabled=$enabled, port=$port, " + s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + s"trustStore=$trustStore, trustStorePassword=${trustStorePassword.map(_ => "xxx")}, " + s"protocol=$protocol, enabledAlgorithms=$enabledAlgorithms}" @@ -142,6 +143,7 @@ private[spark] object SSLOptions extends Logging { * * The following settings are allowed: * $ - `[ns].enabled` - `true` or `false`, to enable or disable SSL respectively + * $ - `[ns].port` - the port where to bind the SSL server * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory * $ - `[ns].keyStorePassword` - a password to the key-store file * $ - `[ns].keyPassword` - a password to the private key @@ -162,11 +164,16 @@ private[spark] object SSLOptions extends Logging { * missing in SparkConf, the corresponding setting is used from the default configuration. * * @param conf Spark configuration object where the settings are collected from + * @param hadoopConf Hadoop configuration to get settings * @param ns the namespace name * @param defaults the default configuration * @return [[org.apache.spark.SSLOptions]] object */ - def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { + def parse( + conf: SparkConf, + hadoopConf: Configuration, + ns: String, + defaults: Option[SSLOptions] = None): SSLOptions = { val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) val port = conf.getWithSubstitution(s"$ns.port").map(_.toInt) @@ -178,9 +185,11 @@ private[spark] object SSLOptions extends Logging { .orElse(defaults.flatMap(_.keyStore)) val keyStorePassword = conf.getWithSubstitution(s"$ns.keyStorePassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.keyStorePassword")).map(new String(_))) .orElse(defaults.flatMap(_.keyStorePassword)) val keyPassword = conf.getWithSubstitution(s"$ns.keyPassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.keyPassword")).map(new String(_))) .orElse(defaults.flatMap(_.keyPassword)) val keyStoreType = conf.getWithSubstitution(s"$ns.keyStoreType") @@ -193,6 +202,7 @@ private[spark] object SSLOptions extends Logging { .orElse(defaults.flatMap(_.trustStore)) val trustStorePassword = conf.getWithSubstitution(s"$ns.trustStorePassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.trustStorePassword")).map(new String(_))) .orElse(defaults.flatMap(_.trustStorePassword)) val trustStoreType = conf.getWithSubstitution(s"$ns.trustStoreType") diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 09ec8932353a0..3cfafeb951105 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -17,18 +17,13 @@ package org.apache.spark -import java.lang.{Byte => JByte} import java.net.{Authenticator, PasswordAuthentication} import java.nio.charset.StandardCharsets.UTF_8 -import java.security.{KeyStore, SecureRandom} -import java.security.cert.X509Certificate -import javax.net.ssl._ -import com.google.common.hash.HashCodes -import com.google.common.io.Files import org.apache.hadoop.io.Text import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher @@ -89,6 +84,7 @@ private[spark] class SecurityManager( setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", "")); setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", "")); + private var secretKey: String = _ logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + "; ui acls " + (if (aclsOn) "enabled" else "disabled") + "; users with view permissions: " + viewAcls.toString() + @@ -115,11 +111,14 @@ private[spark] class SecurityManager( ) } + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) // the default SSL configuration - it will be used by all communication layers unless overwritten - private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None) + private val defaultSSLOptions = + SSLOptions.parse(sparkConf, hadoopConf, "spark.ssl", defaults = None) def getSSLOptions(module: String): SSLOptions = { - val opts = SSLOptions.parse(sparkConf, s"spark.ssl.$module", Some(defaultSSLOptions)) + val opts = + SSLOptions.parse(sparkConf, hadoopConf, s"spark.ssl.$module", Some(defaultSSLOptions)) logDebug(s"Created SSL options for $module: $opts") opts } @@ -321,6 +320,12 @@ private[spark] class SecurityManager( val creds = UserGroupInformation.getCurrentUser().getCredentials() Option(creds.getSecretKey(SECRET_LOOKUP_KEY)) .map { bytes => new String(bytes, UTF_8) } + // Secret key may not be found in current UGI's credentials. + // This happens when UGI is refreshed in the driver side by UGI's loginFromKeytab but not + // copy secret key from original UGI to the new one. This exists in ThriftServer's Hive + // logic. So as a workaround, storing secret key in a local variable to make it visible + // in different context. + .orElse(Option(secretKey)) .orElse(Option(sparkConf.getenv(ENV_AUTH_SECRET))) .orElse(sparkConf.getOption(SPARK_AUTH_SECRET_CONF)) .getOrElse { @@ -358,14 +363,9 @@ private[spark] class SecurityManager( return } - val rnd = new SecureRandom() - val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE - val secretBytes = new Array[Byte](length) - rnd.nextBytes(secretBytes) - + secretKey = Utils.createSecret(sparkConf) val creds = new Credentials() - val secretStr = HashCodes.fromBytes(secretBytes).toString() - creds.addSecretKey(SECRET_LOOKUP_KEY, secretStr.getBytes(UTF_8)) + creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8)) UserGroupInformation.getCurrentUser().addCredentials(creds) } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 129956e9f9ffa..6c4c5c94cfa28 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -265,16 +265,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then seconds are assumed. * @throws java.util.NoSuchElementException If the time parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as seconds */ - def getTimeAsSeconds(key: String): Long = { + def getTimeAsSeconds(key: String): Long = catchIllegalValue(key) { Utils.timeStringAsSeconds(get(key)) } /** * Get a time parameter as seconds, falling back to a default if not set. If no * suffix is provided then seconds are assumed. + * @throws NumberFormatException If the value cannot be interpreted as seconds */ - def getTimeAsSeconds(key: String, defaultValue: String): Long = { + def getTimeAsSeconds(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.timeStringAsSeconds(get(key, defaultValue)) } @@ -282,16 +284,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then milliseconds are assumed. * @throws java.util.NoSuchElementException If the time parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as milliseconds */ - def getTimeAsMs(key: String): Long = { + def getTimeAsMs(key: String): Long = catchIllegalValue(key) { Utils.timeStringAsMs(get(key)) } /** * Get a time parameter as milliseconds, falling back to a default if not set. If no * suffix is provided then milliseconds are assumed. + * @throws NumberFormatException If the value cannot be interpreted as milliseconds */ - def getTimeAsMs(key: String, defaultValue: String): Long = { + def getTimeAsMs(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.timeStringAsMs(get(key, defaultValue)) } @@ -299,23 +303,26 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as bytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then bytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as bytes */ - def getSizeAsBytes(key: String): Long = { + def getSizeAsBytes(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsBytes(get(key)) } /** * Get a size parameter as bytes, falling back to a default if not set. If no * suffix is provided then bytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as bytes */ - def getSizeAsBytes(key: String, defaultValue: String): Long = { + def getSizeAsBytes(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsBytes(get(key, defaultValue)) } /** * Get a size parameter as bytes, falling back to a default if not set. + * @throws NumberFormatException If the value cannot be interpreted as bytes */ - def getSizeAsBytes(key: String, defaultValue: Long): Long = { + def getSizeAsBytes(key: String, defaultValue: Long): Long = catchIllegalValue(key) { Utils.byteStringAsBytes(get(key, defaultValue + "B")) } @@ -323,16 +330,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as Kibibytes */ - def getSizeAsKb(key: String): Long = { + def getSizeAsKb(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsKb(get(key)) } /** * Get a size parameter as Kibibytes, falling back to a default if not set. If no * suffix is provided then Kibibytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as Kibibytes */ - def getSizeAsKb(key: String, defaultValue: String): Long = { + def getSizeAsKb(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsKb(get(key, defaultValue)) } @@ -340,16 +349,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Mebibytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as Mebibytes */ - def getSizeAsMb(key: String): Long = { + def getSizeAsMb(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsMb(get(key)) } /** * Get a size parameter as Mebibytes, falling back to a default if not set. If no * suffix is provided then Mebibytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as Mebibytes */ - def getSizeAsMb(key: String, defaultValue: String): Long = { + def getSizeAsMb(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsMb(get(key, defaultValue)) } @@ -357,16 +368,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Gibibytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as Gibibytes */ - def getSizeAsGb(key: String): Long = { + def getSizeAsGb(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsGb(get(key)) } /** * Get a size parameter as Gibibytes, falling back to a default if not set. If no * suffix is provided then Gibibytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as Gibibytes */ - def getSizeAsGb(key: String, defaultValue: String): Long = { + def getSizeAsGb(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsGb(get(key, defaultValue)) } @@ -394,23 +407,35 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } - /** Get a parameter as an integer, falling back to a default if not set */ - def getInt(key: String, defaultValue: Int): Int = { + /** + * Get a parameter as an integer, falling back to a default if not set + * @throws NumberFormatException If the value cannot be interpreted as an integer + */ + def getInt(key: String, defaultValue: Int): Int = catchIllegalValue(key) { getOption(key).map(_.toInt).getOrElse(defaultValue) } - /** Get a parameter as a long, falling back to a default if not set */ - def getLong(key: String, defaultValue: Long): Long = { + /** + * Get a parameter as a long, falling back to a default if not set + * @throws NumberFormatException If the value cannot be interpreted as a long + */ + def getLong(key: String, defaultValue: Long): Long = catchIllegalValue(key) { getOption(key).map(_.toLong).getOrElse(defaultValue) } - /** Get a parameter as a double, falling back to a default if not set */ - def getDouble(key: String, defaultValue: Double): Double = { + /** + * Get a parameter as a double, falling back to a default if not ste + * @throws NumberFormatException If the value cannot be interpreted as a double + */ + def getDouble(key: String, defaultValue: Double): Double = catchIllegalValue(key) { getOption(key).map(_.toDouble).getOrElse(defaultValue) } - /** Get a parameter as a boolean, falling back to a default if not set */ - def getBoolean(key: String, defaultValue: Boolean): Boolean = { + /** + * Get a parameter as a boolean, falling back to a default if not set + * @throws IllegalArgumentException If the value cannot be interpreted as a boolean + */ + def getBoolean(key: String, defaultValue: Boolean): Boolean = catchIllegalValue(key) { getOption(key).map(_.toBoolean).getOrElse(defaultValue) } @@ -448,14 +473,33 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria */ private[spark] def getenv(name: String): String = System.getenv(name) + /** + * Wrapper method for get() methods which require some specific value format. This catches + * any [[NumberFormatException]] or [[IllegalArgumentException]] and re-raises it with the + * incorrectly configured key in the exception message. + */ + private def catchIllegalValue[T](key: String)(getValue: => T): T = { + try { + getValue + } catch { + case e: NumberFormatException => + // NumberFormatException doesn't have a constructor that takes a cause for some reason. + throw new NumberFormatException(s"Illegal value for config key $key: ${e.getMessage}") + .initCause(e) + case e: IllegalArgumentException => + throw new IllegalArgumentException(s"Illegal value for config key $key: ${e.getMessage}", e) + } + } + /** * Checks for illegal or deprecated config settings. Throws an exception for the former. Not * idempotent - may mutate this conf object to convert deprecated settings to supported ones. */ private[spark] def validateSettings() { if (contains("spark.local.dir")) { - val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " + - "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)." + val msg = "Note that spark.local.dir will be overridden by the value set by " + + "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS" + + " in YARN)." logWarning(msg) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5e8595603cc90..531384ab57305 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1306,11 +1306,12 @@ class SparkContext(config: SparkConf) extends Logging { /** Build the union of a list of RDDs. */ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = withScope { - val partitioners = rdds.flatMap(_.partitioner).toSet - if (rdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { - new PartitionerAwareUnionRDD(this, rdds) + val nonEmptyRdds = rdds.filter(!_.partitions.isEmpty) + val partitioners = nonEmptyRdds.flatMap(_.partitioner).toSet + if (nonEmptyRdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { + new PartitionerAwareUnionRDD(this, nonEmptyRdds) } else { - new UnionRDD(this, rdds) + new UnionRDD(this, nonEmptyRdds) } } @@ -1495,6 +1496,8 @@ class SparkContext(config: SparkConf) extends Logging { * @param path can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String): Unit = { addFile(path, false) @@ -1515,6 +1518,8 @@ class SparkContext(config: SparkConf) extends Logging { * use `SparkFiles.get(fileName)` to find its download location. * @param recursive if true, a directory can be given in `path`. Currently directories are * only supported for Hadoop-supported filesystems. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String, recursive: Boolean): Unit = { val uri = new Path(path).toUri @@ -1554,6 +1559,9 @@ class SparkContext(config: SparkConf) extends Logging { Utils.fetchFile(uri.toString, new File(SparkFiles.getRootDirectory()), conf, env.securityManager, hadoopConfiguration, timestamp, useCache = false) postEnvironmentUpdate() + } else { + logWarning(s"The path $path has been added already. Overwriting of added paths " + + "is not supported in the current version.") } } @@ -1802,6 +1810,8 @@ class SparkContext(config: SparkConf) extends Logging { * * @param path can be either a local file, a file in HDFS (or other Hadoop-supported filesystems), * an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addJar(path: String) { def addJarFile(file: File): String = { @@ -1848,6 +1858,9 @@ class SparkContext(config: SparkConf) extends Logging { if (addedJars.putIfAbsent(key, timestamp).isEmpty) { logInfo(s"Added JAR $path at $key with timestamp $timestamp") postEnvironmentUpdate() + } else { + logWarning(s"The jar $path has been added already. Overwriting of added jars " + + "is not supported in the current version.") } } } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index cccd3ea457ba4..0791fe856ef15 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -178,4 +178,6 @@ private[spark] class TaskContextImpl( private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + // TODO: shall we publish it and define it in `TaskContext`? + private[spark] def getLocalProperties(): Properties = localProperties } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index a76283e33fa65..33901bc8380e9 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -212,9 +212,15 @@ case object TaskResultLost extends TaskFailedReason { * Task was killed intentionally and needs to be rescheduled. */ @DeveloperApi -case class TaskKilled(reason: String) extends TaskFailedReason { +case class TaskKilled( + reason: String, + accumUpdates: Seq[AccumulableInfo] = Seq.empty, + private[spark] val accums: Seq[AccumulatorV2[_, _]] = Nil) + extends TaskFailedReason { + override def toErrorString: String = s"TaskKilled ($reason)" override def countTowardsTaskFailures: Boolean = false + } /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index f1936bf587282..09c83849e26b2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -668,6 +668,8 @@ class JavaSparkContext(val sc: SparkContext) * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String) { sc.addFile(path) @@ -681,6 +683,8 @@ class JavaSparkContext(val sc: SparkContext) * * A directory can be given if the recursive option is set to true. Currently directories are only * supported for Hadoop-supported filesystems. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String, recursive: Boolean): Unit = { sc.addFile(path, recursive) @@ -690,6 +694,8 @@ class JavaSparkContext(val sc: SparkContext) * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addJar(path: String) { sc.addJar(path) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala index 11f2432575d84..9ddc4a4910180 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala @@ -17,26 +17,39 @@ package org.apache.spark.api.python -import java.io.DataOutputStream -import java.net.Socket +import java.io.{DataOutputStream, File, FileOutputStream} +import java.net.InetAddress +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.Files import py4j.GatewayServer +import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.Utils /** - * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port - * back to its caller via a callback port specified by the caller. + * Process that starts a Py4J GatewayServer on an ephemeral port. * * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py). */ private[spark] object PythonGatewayServer extends Logging { initializeLogIfNecessary(true) - def main(args: Array[String]): Unit = Utils.tryOrExit { - // Start a GatewayServer on an ephemeral port - val gatewayServer: GatewayServer = new GatewayServer(null, 0) + def main(args: Array[String]): Unit = { + val secret = Utils.createSecret(new SparkConf()) + + // Start a GatewayServer on an ephemeral port. Make sure the callback client is configured + // with the same secret, in case the app needs callbacks from the JVM to the underlying + // python processes. + val localhost = InetAddress.getLoopbackAddress() + val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder() + .authToken(secret) + .javaPort(0) + .javaAddress(localhost) + .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret) + .build() + gatewayServer.start() val boundPort: Int = gatewayServer.getListeningPort if (boundPort == -1) { @@ -46,15 +59,24 @@ private[spark] object PythonGatewayServer extends Logging { logDebug(s"Started PythonGatewayServer on port $boundPort") } - // Communicate the bound port back to the caller via the caller-specified callback port - val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST") - val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt - logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort") - val callbackSocket = new Socket(callbackHost, callbackPort) - val dos = new DataOutputStream(callbackSocket.getOutputStream) + // Communicate the connection information back to the python process by writing the + // information in the requested file. This needs to match the read side in java_gateway.py. + val connectionInfoPath = new File(sys.env("_PYSPARK_DRIVER_CONN_INFO_PATH")) + val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(), + "connection", ".info").toFile() + + val dos = new DataOutputStream(new FileOutputStream(tmpPath)) dos.writeInt(boundPort) + + val secretBytes = secret.getBytes(UTF_8) + dos.writeInt(secretBytes.length) + dos.write(secretBytes, 0, secretBytes.length) dos.close() - callbackSocket.close() + + if (!tmpPath.renameTo(connectionInfoPath)) { + logError(s"Unable to write connection information to $connectionInfoPath.") + System.exit(1) + } // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies: while (System.in.read() != -1) { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f6293c0dc5091..a1ee2f7d1b119 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -38,6 +38,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ @@ -107,6 +108,12 @@ private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() + // Authentication helper used when serving iterator data. + private lazy val authHelper = { + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + new SocketAuthHelper(conf) + } + def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = { synchronized { workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) @@ -129,12 +136,13 @@ private[spark] object PythonRDD extends Logging { * (effectively a collect()), but allows you to run on a certain subset of partitions, * or to enable local execution. * - * @return the port number of a local socket which serves the data collected from this job. + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ def runJob( sc: SparkContext, rdd: JavaRDD[Array[Byte]], - partitions: JArrayList[Int]): Int = { + partitions: JArrayList[Int]): Array[Any] = { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = @@ -147,13 +155,14 @@ private[spark] object PythonRDD extends Logging { /** * A helper function to collect an RDD as an iterator, then serve it via socket. * - * @return the port number of a local socket which serves the data collected from this job. + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ - def collectAndServe[T](rdd: RDD[T]): Int = { + def collectAndServe[T](rdd: RDD[T]): Array[Any] = { serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}") } - def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = { + def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = { serveIterator(rdd.toLocalIterator, s"serve toLocalIterator") } @@ -384,8 +393,11 @@ private[spark] object PythonRDD extends Logging { * and send them into this connection. * * The thread will terminate after all the data are sent or any exceptions happen. + * + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ - def serveIterator[T](items: Iterator[T], threadName: String): Int = { + def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 15 seconds serverSocket.setSoTimeout(15000) @@ -395,11 +407,14 @@ private[spark] object PythonRDD extends Logging { override def run() { try { val sock = serverSocket.accept() + authHelper.authClient(sock) + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) Utils.tryWithSafeFinally { writeIteratorToStream(items, out) } { out.close() + sock.close() } } catch { case NonFatal(e) => @@ -410,7 +425,7 @@ private[spark] object PythonRDD extends Logging { } }.start() - serverSocket.getLocalPort + Array(serverSocket.getLocalPort, authHelper.secret) } private def getMergedConf(confAsMap: java.util.HashMap[String, String], diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index f075a7e0eb0b4..ebabedf950e39 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -40,6 +40,7 @@ private[spark] object PythonEvalType { val SQL_SCALAR_PANDAS_UDF = 200 val SQL_GROUPED_MAP_PANDAS_UDF = 201 val SQL_GROUPED_AGG_PANDAS_UDF = 202 + val SQL_WINDOW_AGG_PANDAS_UDF = 203 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -47,6 +48,7 @@ private[spark] object PythonEvalType { case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF" case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF" case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" + case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF" } } @@ -183,6 +185,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( dataOut.writeInt(context.partitionId()) dataOut.writeInt(context.attemptNumber()) dataOut.writeLong(context.taskAttemptId()) + val localProps = context.asInstanceOf[TaskContextImpl].getLocalProperties.asScala + dataOut.writeInt(localProps.size) + localProps.foreach { case (k, v) => + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) + } + // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 92e228a9dd10c..27a5e19f96a14 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.6-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.7-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 2340580b54f67..6afa37aa36fd3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -27,6 +27,7 @@ import scala.collection.mutable import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util.{RedirectThread, Utils} private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) @@ -67,6 +68,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String value }.getOrElse("pyspark.worker") + private val authHelper = new SocketAuthHelper(SparkEnv.get.conf) + var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 @@ -108,6 +111,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String if (pid < 0) { throw new IllegalStateException("Python daemon failed to launch worker with code " + pid) } + + authHelper.authToServer(socket) daemonWorkers.put(socket, pid) socket } @@ -145,25 +150,24 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") + workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString) + workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) val worker = pb.start() // Redirect worker stdout and stderr redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream) - // Tell the worker our port - val out = new OutputStreamWriter(worker.getOutputStream, StandardCharsets.UTF_8) - out.write(serverSocket.getLocalPort + "\n") - out.flush() - - // Wait for it to connect to our socket + // Wait for it to connect to our socket, and validate the auth secret. serverSocket.setSoTimeout(10000) + try { val socket = serverSocket.accept() + authHelper.authClient(socket) simpleWorkers.put(socket, worker) return socket } catch { case e: Exception => - throw new SparkException("Python worker did not connect back in time", e) + throw new SparkException("Python worker failed to connect back.", e) } } finally { if (serverSocket != null) { @@ -187,6 +191,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) + workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") daemon = pb.start() @@ -218,7 +223,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // Redirect daemon stdout and stderr redirectStreamsToStderr(in, daemon.getErrorStream) - } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala new file mode 100644 index 0000000000000..ac6826a9ec774 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataInputStream, DataOutputStream} +import java.net.Socket + +import org.apache.spark.SparkConf +import org.apache.spark.security.SocketAuthHelper + +private[spark] class RAuthHelper(conf: SparkConf) extends SocketAuthHelper(conf) { + + override protected def readUtf8(s: Socket): String = { + SerDe.readString(new DataInputStream(s.getInputStream())) + } + + override protected def writeUtf8(str: String, s: Socket): Unit = { + val out = s.getOutputStream() + SerDe.writeString(new DataOutputStream(out), str) + out.flush() + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 2d1152a036449..3b2e809408e0f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -17,8 +17,8 @@ package org.apache.spark.api.r -import java.io.{DataOutputStream, File, FileOutputStream, IOException} -import java.net.{InetAddress, InetSocketAddress, ServerSocket} +import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream, IOException} +import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap @@ -32,6 +32,8 @@ import io.netty.handler.timeout.ReadTimeoutHandler import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils /** * Netty-based backend server that is used to communicate between R and Java. @@ -45,7 +47,7 @@ private[spark] class RBackend { /** Tracks JVM objects returned to R for this RBackend instance. */ private[r] val jvmObjectTracker = new JVMObjectTracker - def init(): Int = { + def init(): (Int, RAuthHelper) = { val conf = new SparkConf() val backendConnectionTimeout = conf.getInt( "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) @@ -53,6 +55,7 @@ private[spark] class RBackend { conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS)) val workerGroup = bossGroup val handler = new RBackendHandler(this) + val authHelper = new RAuthHelper(conf) bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) @@ -71,13 +74,16 @@ private[spark] class RBackend { new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout)) + .addLast(new RBackendAuthHandler(authHelper.secret)) .addLast("handler", handler) } }) channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0)) channelFuture.syncUninterruptibly() - channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + + val port = channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + (port, authHelper) } def run(): Unit = { @@ -116,7 +122,7 @@ private[spark] object RBackend extends Logging { val sparkRBackend = new RBackend() try { // bind to random port - val boundPort = sparkRBackend.init() + val (boundPort, authHelper) = sparkRBackend.init() val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() // Connection timeout is set by socket client. To make it configurable we will pass the @@ -133,6 +139,7 @@ private[spark] object RBackend extends Logging { dos.writeInt(listenPort) SerDe.writeString(dos, RUtils.rPackages.getOrElse("")) dos.writeInt(backendConnectionTimeout) + SerDe.writeString(dos, authHelper.secret) dos.close() f.renameTo(new File(path)) @@ -144,12 +151,35 @@ private[spark] object RBackend extends Logging { val buf = new Array[Byte](1024) // shutdown JVM if R does not connect back in 10 seconds serverSocket.setSoTimeout(10000) + + // Wait for the R process to connect back, ignoring any failed auth attempts. Allow + // a max number of connection attempts to avoid looping forever. try { - val inSocket = serverSocket.accept() + var remainingAttempts = 10 + var inSocket: Socket = null + while (inSocket == null) { + inSocket = serverSocket.accept() + try { + authHelper.authClient(inSocket) + } catch { + case e: Exception => + remainingAttempts -= 1 + if (remainingAttempts == 0) { + val msg = "Too many failed authentication attempts." + logError(msg) + throw new IllegalStateException(msg) + } + logInfo("Client connection failed authentication.") + inSocket = null + } + } + serverSocket.close() + // wait for the end of socket, closed if R process die inSocket.getInputStream().read(buf) } finally { + serverSocket.close() sparkRBackend.close() System.exit(0) } @@ -165,4 +195,5 @@ private[spark] object RBackend extends Logging { } System.exit(0) } + } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala new file mode 100644 index 0000000000000..4162e4a6c7476 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{ByteArrayOutputStream, DataOutputStream} +import java.nio.charset.StandardCharsets.UTF_8 + +import io.netty.channel.{Channel, ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * Authentication handler for connections from the R process. + */ +private class RBackendAuthHandler(secret: String) + extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + // The R code adds a null terminator to serialized strings, so ignore it here. + val clientSecret = new String(msg, 0, msg.length - 1, UTF_8) + try { + require(secret == clientSecret, "Auth secret mismatch.") + ctx.pipeline().remove(this) + writeReply("ok", ctx.channel()) + } catch { + case e: Exception => + logInfo("Authentication failure.", e) + writeReply("err", ctx.channel()) + ctx.close() + } + } + + private def writeReply(reply: String, chan: Channel): Unit = { + val out = new ByteArrayOutputStream() + SerDe.writeString(new DataOutputStream(out), reply) + chan.writeAndFlush(out.toByteArray()) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 88118392003e8..e7fdc3963945a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -74,14 +74,19 @@ private[spark] class RRunner[U]( // the socket used to send out the input of task serverSocket.setSoTimeout(10000) - val inSocket = serverSocket.accept() - startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) - - // the socket used to receive the output of task - val outSocket = serverSocket.accept() - val inputStream = new BufferedInputStream(outSocket.getInputStream) - dataStream = new DataInputStream(inputStream) - serverSocket.close() + dataStream = try { + val inSocket = serverSocket.accept() + RRunner.authHelper.authClient(inSocket) + startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + RRunner.authHelper.authClient(outSocket) + val inputStream = new BufferedInputStream(outSocket.getInputStream) + new DataInputStream(inputStream) + } finally { + serverSocket.close() + } try { return new Iterator[U] { @@ -315,6 +320,11 @@ private[r] object RRunner { private[this] var errThread: BufferedStreamThread = _ private[this] var daemonChannel: DataOutputStream = _ + private lazy val authHelper = { + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + new RAuthHelper(conf) + } + /** * Start a thread to print the process's stderr to ours */ @@ -349,6 +359,7 @@ private[r] object RRunner { pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory()) pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE") + pb.environment().put("SPARKR_WORKER_SECRET", authHelper.secret) pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() val errThread = startStdoutThread(proc) @@ -370,8 +381,12 @@ private[r] object RRunner { // the socket used to send out the input of task serverSocket.setSoTimeout(10000) val sock = serverSocket.accept() - daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - serverSocket.close() + try { + authHelper.authClient(sock) + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + } finally { + serverSocket.close() + } } try { daemonChannel.writeInt(port) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index f975fa5cb4e23..b59a4fe66587c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -94,6 +94,11 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana blockHandler.applicationRemoved(appId, true /* cleanupLocalDirs */) } + /** Clean up all the non-shuffle files associated with an executor that has exited. */ + def executorRemoved(executorId: String, appId: String): Unit = { + blockHandler.executorRemoved(executorId, appId) + } + def stop() { if (server != null) { server.close() diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 7aca305783a7f..ccb30e205ca40 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -18,7 +18,8 @@ package org.apache.spark.deploy import java.io.File -import java.net.URI +import java.net.{InetAddress, URI} +import java.nio.file.Files import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -39,6 +40,7 @@ object PythonRunner { val pyFiles = args(1) val otherArgs = args.slice(2, args.length) val sparkConf = new SparkConf() + val secret = Utils.createSecret(sparkConf) val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON) .orElse(sparkConf.get(PYSPARK_PYTHON)) .orElse(sys.env.get("PYSPARK_DRIVER_PYTHON")) @@ -47,11 +49,17 @@ object PythonRunner { // Format python file paths before adding them to the PYTHONPATH val formattedPythonFile = formatPath(pythonFile) - val formattedPyFiles = formatPaths(pyFiles) + val formattedPyFiles = resolvePyFiles(formatPaths(pyFiles)) // Launch a Py4J gateway server for the process to connect to; this will let it see our // Java system properties and such - val gatewayServer = new py4j.GatewayServer(null, 0) + val localhost = InetAddress.getLoopbackAddress() + val gatewayServer = new py4j.GatewayServer.GatewayServerBuilder() + .authToken(secret) + .javaPort(0) + .javaAddress(localhost) + .callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret) + .build() val thread = new Thread(new Runnable() { override def run(): Unit = Utils.logUncaughtExceptions { gatewayServer.start() @@ -82,6 +90,7 @@ object PythonRunner { // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) + env.put("PYSPARK_GATEWAY_SECRET", secret) // pass conf spark.pyspark.python to python process, the only way to pass info to // python process is through environment variable. sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _)) @@ -145,4 +154,30 @@ object PythonRunner { .map { p => formatPath(p, testWindows) } } + /** + * Resolves the ".py" files. ".py" file should not be added as is because PYTHONPATH does + * not expect a file. This method creates a temporary directory and puts the ".py" files + * if exist in the given paths. + */ + private def resolvePyFiles(pyFiles: Array[String]): Array[String] = { + lazy val dest = Utils.createTempDir(namePrefix = "localPyFiles") + pyFiles.flatMap { pyFile => + // In case of client with submit, the python paths should be set before context + // initialization because the context initialization can be done later. + // We will copy the local ".py" files because ".py" file shouldn't be added + // alone but its parent directory in PYTHONPATH. See SPARK-24384. + if (pyFile.endsWith(".py")) { + val source = new File(pyFile) + if (source.exists() && source.isFile && source.canRead) { + Files.copy(source.toPath, new File(dest, source.getName).toPath) + Some(dest.getAbsolutePath) + } else { + // Don't have to add it if it doesn't exist or isn't readable. + None + } + } else { + Some(pyFile) + } + }.distinct + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index 6eb53a8252205..e86b362639e57 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -68,10 +68,13 @@ object RRunner { // Java system properties etc. val sparkRBackend = new RBackend() @volatile var sparkRBackendPort = 0 + @volatile var sparkRBackendSecret: String = null val initialized = new Semaphore(0) val sparkRBackendThread = new Thread("SparkR backend") { override def run() { - sparkRBackendPort = sparkRBackend.init() + val (port, authHelper) = sparkRBackend.init() + sparkRBackendPort = port + sparkRBackendSecret = authHelper.secret initialized.release() sparkRBackend.run() } @@ -91,6 +94,7 @@ object RRunner { env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) env.put("R_PROFILE_USER", Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator)) + env.put("SPARKR_BACKEND_AUTH_SECRET", sparkRBackendSecret) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 427c797755b84..e7310ee886103 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -22,6 +22,7 @@ import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowab import java.net.URL import java.security.PrivilegedExceptionAction import java.text.ParseException +import java.util.UUID import scala.annotation.tailrec import scala.collection.mutable.{ArrayBuffer, HashMap, Map} @@ -284,8 +285,6 @@ private[spark] class SparkSubmit extends Logging { case (STANDALONE, CLUSTER) if args.isR => error("Cluster deploy mode is currently not supported for R " + "applications on standalone clusters.") - case (KUBERNETES, _) if args.isPython => - error("Python applications are currently not supported for Kubernetes.") case (KUBERNETES, _) if args.isR => error("R applications are currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => @@ -309,6 +308,7 @@ private[spark] class SparkSubmit extends Logging { val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER val isStandAloneCluster = clusterManager == STANDALONE && deployMode == CLUSTER val isKubernetesCluster = clusterManager == KUBERNETES && deployMode == CLUSTER + val isMesosClient = clusterManager == MESOS && deployMode == CLIENT if (!isMesosCluster && !isStandAloneCluster) { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files @@ -336,7 +336,7 @@ private[spark] class SparkSubmit extends Logging { val targetDir = Utils.createTempDir() // assure a keytab is available from any place in a JVM - if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { + if (clusterManager == YARN || clusterManager == LOCAL || isMesosClient) { if (args.principal != null) { if (args.keytab != null) { require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") @@ -385,7 +385,7 @@ private[spark] class SparkSubmit extends Logging { val forceDownloadSchemes = sparkConf.get(FORCE_DOWNLOAD_SCHEMES) def shouldDownload(scheme: String): Boolean = { - forceDownloadSchemes.contains(scheme) || + forceDownloadSchemes.contains("*") || forceDownloadSchemes.contains(scheme) || Try { FileSystem.getFileSystemClass(scheme, hadoopConf) }.isFailure } @@ -428,18 +428,15 @@ private[spark] class SparkSubmit extends Logging { // Usage: PythonAppRunner
[app arguments] args.mainClass = "org.apache.spark.deploy.PythonRunner" args.childArgs = ArrayBuffer(localPrimaryResource, localPyFiles) ++ args.childArgs - if (clusterManager != YARN) { - // The YARN backend distributes the primary file differently, so don't merge it. - args.files = mergeFileLists(args.files, args.primaryResource) - } } if (clusterManager != YARN) { // The YARN backend handles python files differently, so don't merge the lists. args.files = mergeFileLists(args.files, args.pyFiles) } - if (localPyFiles != null) { - sparkConf.set("spark.submit.pyFiles", localPyFiles) - } + } + + if (localPyFiles != null) { + sparkConf.set("spark.submit.pyFiles", localPyFiles) } // In YARN mode for an R app, add the SparkR package archive and the R package @@ -581,7 +578,8 @@ private[spark] class SparkSubmit extends Logging { } // Add the main application jar and any added jars to classpath in case YARN client // requires these jars. - // This assumes both primaryResource and user jars are local jars, otherwise it will not be + // This assumes both primaryResource and user jars are local jars, or already downloaded + // to local by configuring "spark.yarn.dist.forceDownloadSchemes", otherwise it will not be // added to the classpath of YARN client. if (isYarnCluster) { if (isUserJar(args.primaryResource)) { @@ -695,9 +693,19 @@ private[spark] class SparkSubmit extends Logging { if (isKubernetesCluster) { childMainClass = KUBERNETES_CLUSTER_SUBMIT_CLASS if (args.primaryResource != SparkLauncher.NO_RESOURCE) { - childArgs ++= Array("--primary-java-resource", args.primaryResource) + if (args.isPython) { + childArgs ++= Array("--primary-py-file", args.primaryResource) + childArgs ++= Array("--main-class", "org.apache.spark.deploy.PythonRunner") + if (args.pyFiles != null) { + childArgs ++= Array("--other-py-files", args.pyFiles) + } + } else { + childArgs ++= Array("--primary-java-resource", args.primaryResource) + childArgs ++= Array("--main-class", args.mainClass) + } + } else { + childArgs ++= Array("--main-class", args.mainClass) } - childArgs ++= Array("--main-class", args.mainClass) if (args.childArgs != null) { args.childArgs.foreach { arg => childArgs += ("--arg", arg) @@ -1204,7 +1212,33 @@ private[spark] object SparkSubmitUtils { /** A nice function to use in tests as well. Values are dummy strings. */ def getModuleDescriptor: DefaultModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( - ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0")) + // Include UUID in module name, so multiple clients resolving maven coordinate at the same time + // do not modify the same resolution file concurrently. + ModuleRevisionId.newInstance("org.apache.spark", + s"spark-submit-parent-${UUID.randomUUID.toString}", + "1.0")) + + /** + * Clear ivy resolution from current launch. The resolution file is usually at + * ~/.ivy2/org.apache.spark-spark-submit-parent-$UUID-default.xml, + * ~/.ivy2/resolved-org.apache.spark-spark-submit-parent-$UUID-1.0.xml, and + * ~/.ivy2/resolved-org.apache.spark-spark-submit-parent-$UUID-1.0.properties. + * Since each launch will have its own resolution files created, delete them after + * each resolution to prevent accumulation of these files in the ivy cache dir. + */ + private def clearIvyResolutionFiles( + mdId: ModuleRevisionId, + ivySettings: IvySettings, + ivyConfName: String): Unit = { + val currentResolutionFiles = Seq( + s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml", + s"resolved-${mdId.getOrganisation}-${mdId.getName}-${mdId.getRevision}.xml", + s"resolved-${mdId.getOrganisation}-${mdId.getName}-${mdId.getRevision}.properties" + ) + currentResolutionFiles.foreach { filename => + new File(ivySettings.getDefaultCache, filename).delete() + } + } /** * Resolves any dependencies that were supplied through maven coordinates @@ -1255,14 +1289,6 @@ private[spark] object SparkSubmitUtils { // A Module descriptor must be specified. Entries are dummy strings val md = getModuleDescriptor - // clear ivy resolution from previous launches. The resolution file is usually at - // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. In between runs, this file - // leads to confusion with Ivy when the files can no longer be found at the repository - // declared in that file/ - val mdId = md.getModuleRevisionId - val previousResolution = new File(ivySettings.getDefaultCache, - s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml") - if (previousResolution.exists) previousResolution.delete md.setDefaultConf(ivyConfName) @@ -1283,7 +1309,10 @@ private[spark] object SparkSubmitUtils { packagesDirectory.getAbsolutePath + File.separator + "[organization]_[artifact]-[revision](-[classifier]).[ext]", retrieveOptions.setConfs(Array(ivyConfName))) - resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + val paths = resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + val mdId = md.getModuleRevisionId + clearIvyResolutionFiles(mdId, ivySettings, ivyConfName) + paths } finally { System.setOut(sysOut) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 0733fdb72cafb..fb232101114b9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -36,7 +36,6 @@ import org.apache.spark.launcher.SparkSubmitArgumentsParser import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils - /** * Parses and encapsulates arguments from the spark-submit script. * The env argument is used for testing. @@ -76,6 +75,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var proxyUser: String = null var principal: String = null var keytab: String = null + private var dynamicAllocationEnabled: Boolean = false // Standalone cluster mode only var supervise: Boolean = false @@ -182,6 +182,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull files = Option(files).orElse(sparkProperties.get("spark.files")).orNull + pyFiles = Option(pyFiles).orElse(sparkProperties.get("spark.submit.pyFiles")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull ivySettingsPath = sparkProperties.get("spark.jars.ivySettings") packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull @@ -198,6 +199,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull + dynamicAllocationEnabled = + sparkProperties.get("spark.dynamicAllocation.enabled").exists("true".equalsIgnoreCase) // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && !isR && primaryResource != null) { @@ -274,12 +277,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (totalExecutorCores != null && Try(totalExecutorCores.toInt).getOrElse(-1) <= 0) { error("Total executor cores must be a positive number") } - if (numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { + if (!dynamicAllocationEnabled && + numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { error("Number of executors must be a positive number") } - if (pyFiles != null && !isPython) { - error("--py-files given but primary resource is not a Python script") - } if (master.startsWith("yarn")) { val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 56db9359e033f..bf1eeb0c1bf59 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,6 +18,8 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} +import java.nio.file.Files +import java.nio.file.attribute.PosixFilePermissions import java.util.{Date, ServiceLoader} import java.util.concurrent.{ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -130,8 +132,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => - require(path.isDirectory(), s"Configured store directory ($path) does not exist.") - val dbPath = new File(path, "listing.ldb") + val perms = PosixFilePermissions.fromString("rwx------") + val dbPath = Files.createDirectories(new File(path, "listing.ldb").toPath(), + PosixFilePermissions.asFileAttribute(perms)).toFile() + val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, AppStatusStore.CURRENT_VERSION, logDir.toString()) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 6fc12d721e6f1..32667ddf5c7ea 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -37,8 +37,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = - ++ - + ++ +
- UIUtils.basicSparkPage(content, "History Server", true) + UIUtils.basicSparkPage(request, content, "History Server", true) } - private def makePageLink(showIncomplete: Boolean): String = { - UIUtils.prependBaseUri("/?" + "showIncomplete=" + showIncomplete) + private def makePageLink(request: HttpServletRequest, showIncomplete: Boolean): String = { + UIUtils.prependBaseUri(request, "/?" + "showIncomplete=" + showIncomplete) } private def isApplicationCompleted(appInfo: ApplicationInfo): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 611fa563a7cd9..56f3f59504a7d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -87,7 +87,7 @@ class HistoryServer( if (!loadAppUi(appId, None) && (!attemptId.isDefined || !loadAppUi(appId, attemptId))) { val msg =
Application {appId} not found.
res.setStatus(HttpServletResponse.SC_NOT_FOUND) - UIUtils.basicSparkPage(msg, "Not Found").foreach { n => + UIUtils.basicSparkPage(req, msg, "Not Found").foreach { n => res.getWriter().write(n.toString) } return @@ -124,7 +124,7 @@ class HistoryServer( attachHandler(ApiRootResource.getServletHandler(this)) - attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(SparkUI.STATIC_RESOURCE_DIR) val contextHandler = new ServletContextHandler contextHandler.setContextPath(HistoryServer.UI_PATH_PREFIX) @@ -152,7 +152,6 @@ class HistoryServer( assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs") handlers.synchronized { ui.getHandlers.foreach(attachHandler) - addFilters(ui.getHandlers, conf) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index f699c75085fe1..fad4e46dc035d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -40,7 +40,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") .getOrElse(state.completedApps.find(_.id == appId).orNull) if (app == null) { val msg =
No running application with ID {appId}
- return UIUtils.basicSparkPage(msg, "Not Found") + return UIUtils.basicSparkPage(request, msg, "Not Found") } val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs") @@ -127,7 +127,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") } ; - UIUtils.basicSparkPage(content, "Application: " + app.desc.name) + UIUtils.basicSparkPage(request, content, "Application: " + app.desc.name) } private def executorRow(executor: ExecutorDesc): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index c629937606b51..b8afe203fbfa2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -215,7 +215,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } ; - UIUtils.basicSparkPage(content, "Spark Master at " + state.uri) + UIUtils.basicSparkPage(request, content, "Spark Master at " + state.uri) } private def workerRow(worker: WorkerInfo): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 35b7ddd46e4db..e87b2240564bd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -43,7 +43,7 @@ class MasterWebUI( val masterPage = new MasterPage(this) attachPage(new ApplicationPage(this)) attachPage(masterPage) - attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR) attachHandler(createRedirectHandler( "/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST"))) attachHandler(createRedirectHandler( diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 742a95841a138..31a8e3e60c067 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -233,30 +233,44 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { import scala.concurrent.ExecutionContext.Implicits.global val responseFuture = Future { - val dataStream = - if (connection.getResponseCode == HttpServletResponse.SC_OK) { - connection.getInputStream - } else { - connection.getErrorStream + val responseCode = connection.getResponseCode + + if (responseCode != HttpServletResponse.SC_OK) { + val errString = Some(Source.fromInputStream(connection.getErrorStream()) + .getLines().mkString("\n")) + if (responseCode == HttpServletResponse.SC_INTERNAL_SERVER_ERROR && + !connection.getContentType().contains("application/json")) { + throw new SubmitRestProtocolException(s"Server responded with exception:\n${errString}") + } + logError(s"Server responded with error:\n${errString}") + val error = new ErrorResponse + if (responseCode == RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) { + error.highestProtocolVersion = RestSubmissionServer.PROTOCOL_VERSION + } + error.message = errString.get + error + } else { + val dataStream = connection.getInputStream + + // If the server threw an exception while writing a response, it will not have a body + if (dataStream == null) { + throw new SubmitRestProtocolException("Server returned empty body") + } + val responseJson = Source.fromInputStream(dataStream).mkString + logDebug(s"Response from the server:\n$responseJson") + val response = SubmitRestProtocolMessage.fromJson(responseJson) + response.validate() + response match { + // If the response is an error, log the message + case error: ErrorResponse => + logError(s"Server responded with error:\n${error.message}") + error + // Otherwise, simply return the response + case response: SubmitRestProtocolResponse => response + case unexpected => + throw new SubmitRestProtocolException( + s"Message received from server was not a response:\n${unexpected.toJson}") } - // If the server threw an exception while writing a response, it will not have a body - if (dataStream == null) { - throw new SubmitRestProtocolException("Server returned empty body") - } - val responseJson = Source.fromInputStream(dataStream).mkString - logDebug(s"Response from the server:\n$responseJson") - val response = SubmitRestProtocolMessage.fromJson(responseJson) - response.validate() - response match { - // If the response is an error, log the message - case error: ErrorResponse => - logError(s"Server responded with error:\n${error.message}") - error - // Otherwise, simply return the response - case response: SubmitRestProtocolResponse => response - case unexpected => - throw new SubmitRestProtocolException( - s"Message received from server was not a response:\n${unexpected.toJson}") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index e88195d95f270..3d99d085408c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -94,6 +94,7 @@ private[spark] abstract class RestSubmissionServer( new HttpConnectionFactory()) connector.setHost(host) connector.setPort(startPort) + connector.setReuseAddress(!Utils.isWindows) server.addConnector(connector) val mainHandler = new ServletContextHandler diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 5151df00476f9..ab8d8d96a9b08 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging * * Also, each HadoopDelegationTokenProvider is controlled by * spark.security.credentials.{service}.enabled, and will not be loaded if this config is set to - * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be + * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be * enabled/disabled by the configuration spark.security.credentials.hive.enabled. * * @param sparkConf Spark configuration @@ -52,7 +52,7 @@ private[spark] class HadoopDelegationTokenManager( // Maintain all the registered delegation token providers private val delegationTokenProviders = getDelegationTokenProviders - logDebug(s"Using the following delegation token providers: " + + logDebug("Using the following builtin delegation token providers: " + s"${delegationTokenProviders.keys.mkString(", ")}.") /** Construct a [[HadoopDelegationTokenManager]] for the default Hadoop filesystem */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 58a181128eb4d..a6d13d12fc28d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -225,7 +225,7 @@ private[deploy] class DriverRunner( // check if attempting another run keepTrying = supervise && exitCode != 0 && !killed if (keepTrying) { - if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) { + if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000L) { waitSeconds = 1 } logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index d4d8521cc8204..dc6a3076a5113 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import com.google.common.io.Files import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} +import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef @@ -142,7 +142,11 @@ private[deploy] class ExecutorRunner( private def fetchAndRunExecutor() { try { // Launch the process - val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf), + val subsOpts = appDesc.command.javaOpts.map { + Utils.substituteAppNExecIds(_, appId, execId.toString) + } + val subsCommand = appDesc.command.copy(javaOpts = subsOpts) + val builder = CommandUtils.buildProcessBuilder(subsCommand, new SecurityManager(conf), memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 563b84934f264..ee1ca0bba5749 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -23,6 +23,7 @@ import java.text.SimpleDateFormat import java.util.{Date, Locale, UUID} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} +import java.util.function.Supplier import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext @@ -49,7 +50,8 @@ private[deploy] class Worker( endpointName: String, workDirPath: String = null, val conf: SparkConf, - val securityMgr: SecurityManager) + val securityMgr: SecurityManager, + externalShuffleServiceSupplier: Supplier[ExternalShuffleService] = null) extends ThreadSafeRpcEndpoint with Logging { private val host = rpcEnv.address.host @@ -97,6 +99,10 @@ private[deploy] class Worker( private val APP_DATA_RETENTION_SECONDS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) + // Whether or not cleanup the non-shuffle files on executor exits. + private val CLEANUP_NON_SHUFFLE_FILES_ENABLED = + conf.getBoolean("spark.storage.cleanupFilesAfterExecutorExit", true) + private val testing: Boolean = sys.props.contains("spark.testing") private var master: Option[RpcEndpointRef] = None @@ -142,7 +148,11 @@ private[deploy] class Worker( WorkerWebUI.DEFAULT_RETAINED_DRIVERS) // The shuffle service is not actually started unless configured. - private val shuffleService = new ExternalShuffleService(conf, securityMgr) + private val shuffleService = if (externalShuffleServiceSupplier != null) { + externalShuffleServiceSupplier.get() + } else { + new ExternalShuffleService(conf, securityMgr) + } private val publicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") @@ -732,6 +742,9 @@ private[deploy] class Worker( trimFinishedExecutorsIfNecessary() coresUsed -= executor.cores memoryUsed -= executor.memory + if (CLEANUP_NON_SHUFFLE_FILES_ENABLED) { + shuffleService.executorRemoved(executorStateChanged.execId.toString, appId) + } case None => logInfo("Unknown Executor " + fullId + " finished with state " + state + message.map(" message " + _).getOrElse("") + diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 2f5a5642d3cab..4fca9342c0378 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -118,7 +118,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with - UIUtils.basicSparkPage(content, logType + " log page for " + pageName) + UIUtils.basicSparkPage(request, content, logType + " log page for " + pageName) } /** Get the part of the log files given the offset and desired length of bytes */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 8b98ae56fc108..aa4e28d213e2b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -135,7 +135,7 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { } ; - UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format( + UIUtils.basicSparkPage(request, content, "Spark Worker at %s:%s".format( workerState.host, workerState.port)) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index db696b04384bd..ea67b7434a769 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -47,7 +47,7 @@ class WorkerWebUI( val logPage = new LogPage(this) attachPage(logPage) attachPage(new WorkerPage(this)) - attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) + addStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE) attachHandler(createServletHandler("/log", (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr, diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c325222b764b8..b1856ff0f3247 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -287,6 +287,28 @@ private[spark] class Executor( notifyAll() } + /** + * Utility function to: + * 1. Report executor runtime and JVM gc time if possible + * 2. Collect accumulator updates + * 3. Set the finished flag to true and clear current thread's interrupt status + */ + private def collectAccumulatorsAndResetStatusOnFailure(taskStartTime: Long) = { + // Report executor runtime and JVM gc time + Option(task).foreach(t => { + t.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStartTime) + t.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) + }) + + // Collect latest accumulator values to report back to the driver + val accums: Seq[AccumulatorV2[_, _]] = + Option(task).map(_.collectAccumulatorUpdates(taskFailed = true)).getOrElse(Seq.empty) + val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + + setTaskFinishedAndClearInterruptStatus() + (accums, accUpdates) + } + override def run(): Unit = { threadId = Thread.currentThread.getId Thread.currentThread.setName(threadName) @@ -300,7 +322,7 @@ private[spark] class Executor( val ser = env.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) - var taskStart: Long = 0 + var taskStartTime: Long = 0 var taskStartCpu: Long = 0 startGCTime = computeTotalGcTime() @@ -336,7 +358,7 @@ private[spark] class Executor( } // Run the actual task and measure its runtime. - taskStart = System.currentTimeMillis() + taskStartTime = System.currentTimeMillis() taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime } else 0L @@ -396,11 +418,11 @@ private[spark] class Executor( // Deserialization happens in two parts: first, we deserialize a Task object, which // includes the Partition. Second, Task.run() deserializes the RDD and function to be run. task.metrics.setExecutorDeserializeTime( - (taskStart - deserializeStartTime) + task.executorDeserializeTime) + (taskStartTime - deserializeStartTime) + task.executorDeserializeTime) task.metrics.setExecutorDeserializeCpuTime( (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime) // We need to subtract Task.run()'s deserialization time to avoid double-counting - task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) + task.metrics.setExecutorRunTime((taskFinish - taskStartTime) - task.executorDeserializeTime) task.metrics.setExecutorCpuTime( (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime) task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) @@ -482,16 +504,19 @@ private[spark] class Executor( } catch { case t: TaskKilledException => logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) + + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) + val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums)) + execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) case _: InterruptedException | NonFatal(_) if task != null && task.reasonIfKilled.isDefined => val killReason = task.reasonIfKilled.getOrElse("unknown reason") logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate( - taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) + + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) + val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates, accums)) + execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => val reason = task.context.fetchFailed.get.toTaskFailedReason @@ -524,17 +549,7 @@ private[spark] class Executor( // the task failure would not be ignored if the shutdown happened because of premption, // instead of an app issue). if (!ShutdownHookManager.inShutdown()) { - // Collect latest accumulator values to report back to the driver - val accums: Seq[AccumulatorV2[_, _]] = - if (task != null) { - task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart) - task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) - task.collectAccumulatorUpdates(taskFailed = true) - } else { - Seq.empty - } - - val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) val serializedTaskEndReason = { try { diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index f47cd38d712c3..04c5c4b90e8a1 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -53,6 +53,19 @@ private[spark] class WholeTextFileInputFormat val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / (if (minPartitions == 0) 1 else minPartitions)).toLong + + // For small files we need to ensure the min split size per node & rack <= maxSplitSize + val config = context.getConfiguration + val minSplitSizePerNode = config.getLong(CombineFileInputFormat.SPLIT_MINSIZE_PERNODE, 0L) + val minSplitSizePerRack = config.getLong(CombineFileInputFormat.SPLIT_MINSIZE_PERRACK, 0L) + + if (maxSplitSize < minSplitSizePerNode) { + super.setMinSplitSizeNode(maxSplitSize) + } + + if (maxSplitSize < minSplitSizePerRack) { + super.setMinSplitSizeRack(maxSplitSize) + } super.setMaxSplitSize(maxSplitSize) } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 99d779fb600e8..ba892bf7f60d6 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -72,6 +72,9 @@ package object config { private[spark] val EVENT_LOG_OVERWRITE = ConfigBuilder("spark.eventLog.overwrite").booleanConf.createWithDefault(false) + private[spark] val EVENT_LOG_CALLSITE_FORM = + ConfigBuilder("spark.eventLog.callsite").stringConf.createWithDefault("short") + private[spark] val EXECUTOR_CLASS_PATH = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional @@ -126,6 +129,10 @@ package object config { private[spark] val DYN_ALLOCATION_MAX_EXECUTORS = ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue) + private[spark] val DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO = + ConfigBuilder("spark.dynamicAllocation.executorAllocationRatio") + .doubleConf.createWithDefault(1.0) + private[spark] val LOCALITY_WAIT = ConfigBuilder("spark.locality.wait") .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("3s") @@ -338,7 +345,7 @@ package object config { "a property key or value, the value is redacted from the environment UI and various logs " + "like YARN and event logs.") .regexConf - .createWithDefault("(?i)secret|password|url|user|username".r) + .createWithDefault("(?i)secret|password".r) private[spark] val STRING_REDACTION_PATTERN = ConfigBuilder("spark.redaction.string.regex") @@ -348,6 +355,11 @@ package object config { .regexConf .createOptional + private[spark] val AUTH_SECRET_BIT_LENGTH = + ConfigBuilder("spark.authenticate.secretBitLength") + .intConf + .createWithDefault(256) + private[spark] val NETWORK_AUTH_ENABLED = ConfigBuilder("spark.authenticate") .booleanConf @@ -474,10 +486,11 @@ package object config { private[spark] val FORCE_DOWNLOAD_SCHEMES = ConfigBuilder("spark.yarn.dist.forceDownloadSchemes") - .doc("Comma-separated list of schemes for which files will be downloaded to the " + + .doc("Comma-separated list of schemes for which resources will be downloaded to the " + "local disk prior to being added to YARN's distributed cache. For use in cases " + "where the YARN service does not support schemes that are supported by Spark, like http, " + - "https and ftp.") + "https and ftp, or jars required to be in the local YARN client's classpath. Wildcard " + + "'*' is denoted to download resources for all the schemes.") .stringConf .toSequence .createWithDefault(Nil) @@ -543,4 +556,11 @@ package object config { .timeConf(TimeUnit.SECONDS) .createWithDefaultString("1h") + private[spark] val SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS = + ConfigBuilder("spark.shuffle.minNumPartitionsToHighlyCompress") + .internal() + .doc("Number of partitions to determine if MapStatus should use HighlyCompressedMapStatus") + .intConf + .checkValue(v => v > 0, "The value should be a positive integer.") + .createWithDefault(2000) } diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala index ddbd624b380d4..af0aa41518766 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala @@ -31,6 +31,8 @@ class HadoopMapRedCommitProtocol(jobId: String, path: String) override def setupCommitter(context: NewTaskAttemptContext): OutputCommitter = { val config = context.getConfiguration.asInstanceOf[JobConf] - config.getOutputCommitter + val committer = config.getOutputCommitter + logInfo(s"Using output committer class ${committer.getClass.getCanonicalName}") + committer } } diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index abf39213fa0d2..9ebd0aa301592 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -76,13 +76,17 @@ object SparkHadoopWriter extends Logging { // Try to write all RDD partitions as a Hadoop OutputFormat. try { val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { + // SPARK-24552: Generate a unique "attempt ID" based on the stage and task attempt numbers. + // Assumes that there won't be more than Short.MaxValue attempts, at least not concurrently. + val attemptId = (context.stageAttemptNumber << 16) | context.attemptNumber + executeTask( context = context, config = config, jobTrackerId = jobTrackerId, commitJobId = commitJobId, sparkPartitionId = context.partitionId, - sparkAttemptNumber = context.attemptNumber, + sparkAttemptNumber = attemptId, committer = committer, iterator = iter) }) diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 764735dc4eae7..db8aff94ea1e1 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -69,9 +69,9 @@ object SparkHadoopMapRedUtil extends Logging { if (shouldCoordinateWithDriver) { val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator - val taskAttemptNumber = TaskContext.get().attemptNumber() - val stageId = TaskContext.get().stageId() - val canCommit = outputCommitCoordinator.canCommit(stageId, splitId, taskAttemptNumber) + val ctx = TaskContext.get() + val canCommit = outputCommitCoordinator.canCommit(ctx.stageId(), ctx.stageAttemptNumber(), + splitId, ctx.attemptNumber()) if (canCommit) { performCommit() @@ -81,7 +81,7 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(message) // We need to abort the task so that the driver can reschedule new attempts, if necessary committer.abortTask(mrTaskContext) - throw new CommitDeniedException(message, stageId, splitId, taskAttemptNumber) + throw new CommitDeniedException(message, ctx.stageId(), splitId, ctx.attemptNumber()) } } else { // Speculation is disabled or a user has chosen to manually bypass the commit coordination diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 13db4985b0b80..ba9dae4ad48ec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -95,7 +95,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // the left side of max is >=1 whenever partsScanned >= 2 numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4L) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 4e036c2ed49b5..23cf19d55b4ae 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -30,7 +30,7 @@ private[spark] class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { - @transient lazy val _locations = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) + @transient lazy val _locations = BlockManager.blockIdsToLocations(blockIds, SparkEnv.get) @volatile private var _isValid = true override def getPartitions: Array[Partition] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala index c1fedd63f6a90..e2b6df4600590 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -34,7 +34,11 @@ import org.apache.spark.util.Utils * Delivery will only begin when the `start()` method is called. The `stop()` method should be * called when no more events need to be delivered. */ -private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveListenerBusMetrics) +private class AsyncEventQueue( + val name: String, + conf: SparkConf, + metrics: LiveListenerBusMetrics, + bus: LiveListenerBus) extends SparkListenerBus with Logging { @@ -81,23 +85,18 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi } private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) { - try { - var next: SparkListenerEvent = eventQueue.take() - while (next != POISON_PILL) { - val ctx = processingTime.time() - try { - super.postToAll(next) - } finally { - ctx.stop() - } - eventCount.decrementAndGet() - next = eventQueue.take() + var next: SparkListenerEvent = eventQueue.take() + while (next != POISON_PILL) { + val ctx = processingTime.time() + try { + super.postToAll(next) + } finally { + ctx.stop() } eventCount.decrementAndGet() - } catch { - case ie: InterruptedException => - logInfo(s"Stopping listener queue $name.", ie) + next = eventQueue.take() } + eventCount.decrementAndGet() } override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = { @@ -130,7 +129,11 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi eventCount.incrementAndGet() eventQueue.put(POISON_PILL) } - dispatchThread.join() + // this thread might be trying to stop itself as part of error handling -- we can't join + // in that case. + if (Thread.currentThread() != dispatchThread) { + dispatchThread.join() + } } def post(event: SparkListenerEvent): Unit = { @@ -187,6 +190,12 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi true } + override def removeListenerOnError(listener: SparkListenerInterface): Unit = { + // the listener failed in an unrecoverably way, we want to remove it from the entire + // LiveListenerBus (potentially stopping a queue if it is empty) + bus.removeListener(listener) + } + } private object AsyncEventQueue { diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index 30cf75d43ee09..980fbbe516b91 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -371,7 +371,7 @@ private[scheduler] class BlacklistTracker ( } -private[scheduler] object BlacklistTracker extends Logging { +private[spark] object BlacklistTracker extends Logging { private val DEFAULT_TIMEOUT = "1h" diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 78b6b34b5d2bb..f74425d73b392 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -206,7 +206,7 @@ class DAGScheduler( private val messageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") - private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) + private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) /** @@ -1167,12 +1167,11 @@ class DAGScheduler( */ private[scheduler] def handleTaskCompletion(event: CompletionEvent) { val task = event.task - val taskId = event.taskInfo.id val stageId = task.stageId - val taskType = Utils.getFormattedClassName(task) outputCommitCoordinator.taskCompleted( stageId, + task.stageAttemptId, task.partitionId, event.taskInfo.attemptNumber, // this is a task attempt number event.reason) @@ -1210,7 +1209,7 @@ class DAGScheduler( case _ => updateAccumulators(event) } - case _: ExceptionFailure => updateAccumulators(event) + case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event) case _ => } postTaskEnd(event) @@ -1323,7 +1322,7 @@ class DAGScheduler( "tasks in ShuffleMapStages.") } - case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => + case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) => val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleIdToMapStage(shuffleId) @@ -1332,23 +1331,24 @@ class DAGScheduler( s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") } else { + failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is // possible the fetch failure has already been handled by the scheduler. if (runningStages.contains(failedStage)) { logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some(failureMessage)) + markStageAsFinished(failedStage, errorMessage = Some(failureMessage), + willRetry = !shouldAbortStage) } else { logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " + s"longer running") } - failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) - val shouldAbortStage = - failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || - disallowStageRetryForTest - if (shouldAbortStage) { val abortMessage = if (disallowStageRetryForTest) { "Fetch failure will not retry stage due to testing config" @@ -1411,16 +1411,16 @@ class DAGScheduler( } } - case commitDenied: TaskCommitDenied => + case _: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits - case exceptionFailure: ExceptionFailure => + case _: ExceptionFailure | _: TaskKilled => // Nothing left to do, already handled above for accumulator updates. case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. - case _: ExecutorLostFailure | _: TaskKilled | UnknownReason => + case _: ExecutorLostFailure | UnknownReason => // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler // will abort the job. } @@ -1547,7 +1547,10 @@ class DAGScheduler( /** * Marks a stage as finished and removes it from the list of running stages. */ - private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { + private def markStageAsFinished( + stage: Stage, + errorMessage: Option[String] = None, + willRetry: Boolean = false): Unit = { val serviceTime = stage.latestInfo.submissionTime match { case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) case _ => "Unknown" @@ -1566,7 +1569,9 @@ class DAGScheduler( logInfo(s"$stage (${stage.name}) failed in $serviceTime s due to ${errorMessage.get}") } - outputCommitCoordinator.stageEnd(stage.id) + if (!willRetry) { + outputCommitCoordinator.stageEnd(stage.id) + } listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) runningStages -= stage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index ba6387a8f08ad..d135190d1e919 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -102,7 +102,7 @@ private[spark] class LiveListenerBus(conf: SparkConf) { queue.addListener(listener) case None => - val newQueue = new AsyncEventQueue(queue, conf, metrics) + val newQueue = new AsyncEventQueue(queue, conf, metrics, this) newQueue.addListener(listener) if (started.get()) { newQueue.start(sparkContext) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 2ec2f2031aa45..659694dd189ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -50,7 +50,9 @@ private[spark] sealed trait MapStatus { private[spark] object MapStatus { def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { - if (uncompressedSizes.length > 2000) { + if (uncompressedSizes.length > Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) + .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) { HighlyCompressedMapStatus(loc, uncompressedSizes) } else { new CompressedMapStatus(loc, uncompressedSizes) diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 83d87b548a430..b382d623806e2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -27,7 +27,11 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils} private sealed trait OutputCommitCoordinationMessage extends Serializable private case object StopCoordinator extends OutputCommitCoordinationMessage -private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int) +private case class AskPermissionToCommitOutput( + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int) /** * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" @@ -45,13 +49,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Initialized by SparkEnv var coordinatorRef: Option[RpcEndpointRef] = None - private type StageId = Int - private type PartitionId = Int - private type TaskAttemptNumber = Int - private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1 + // Class used to identify a committer. The task ID for a committer is implicitly defined by + // the partition being processed, but the coordinator needs to keep track of both the stage + // attempt and the task attempt, because in some situations the same task may be running + // concurrently in two different attempts of the same stage. + private case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int) + private case class StageState(numPartitions: Int) { - val authorizedCommitters = Array.fill[TaskAttemptNumber](numPartitions)(NO_AUTHORIZED_COMMITTER) - val failures = mutable.Map[PartitionId, mutable.Set[TaskAttemptNumber]]() + val authorizedCommitters = Array.fill[TaskIdentifier](numPartitions)(null) + val failures = mutable.Map[Int, mutable.Set[TaskIdentifier]]() } /** @@ -64,7 +70,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ - private val stageStates = mutable.Map[StageId, StageState]() + private val stageStates = mutable.Map[Int, StageState]() /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. @@ -87,10 +93,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * @return true if this task is authorized to commit, false otherwise */ def canCommit( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber): Boolean = { - val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber) + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int): Boolean = { + val msg = AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) coordinatorRef match { case Some(endpointRef) => ThreadUtils.awaitResult(endpointRef.ask[Boolean](msg), @@ -103,26 +110,35 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) } /** - * Called by the DAGScheduler when a stage starts. + * Called by the DAGScheduler when a stage starts. Initializes the stage's state if it hasn't + * yet been initialized. * * @param stage the stage id. * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e. * the maximum possible value of `context.partitionId`). */ - private[scheduler] def stageStart(stage: StageId, maxPartitionId: Int): Unit = synchronized { - stageStates(stage) = new StageState(maxPartitionId + 1) + private[scheduler] def stageStart(stage: Int, maxPartitionId: Int): Unit = synchronized { + stageStates.get(stage) match { + case Some(state) => + require(state.authorizedCommitters.length == maxPartitionId + 1) + logInfo(s"Reusing state from previous attempt of stage $stage.") + + case _ => + stageStates(stage) = new StageState(maxPartitionId + 1) + } } // Called by DAGScheduler - private[scheduler] def stageEnd(stage: StageId): Unit = synchronized { + private[scheduler] def stageEnd(stage: Int): Unit = synchronized { stageStates.remove(stage) } // Called by DAGScheduler private[scheduler] def taskCompleted( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber, + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int, reason: TaskEndReason): Unit = synchronized { val stageState = stageStates.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") @@ -131,16 +147,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) reason match { case Success => // The task output has been committed successfully - case denied: TaskCommitDenied => - logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + - s"attempt: $attemptNumber") - case otherReason => + case _: TaskCommitDenied => + logInfo(s"Task was denied committing, stage: $stage.$stageAttempt, " + + s"partition: $partition, attempt: $attemptNumber") + case _ => // Mark the attempt as failed to blacklist from future commit protocol - stageState.failures.getOrElseUpdate(partition, mutable.Set()) += attemptNumber - if (stageState.authorizedCommitters(partition) == attemptNumber) { + val taskId = TaskIdentifier(stageAttempt, attemptNumber) + stageState.failures.getOrElseUpdate(partition, mutable.Set()) += taskId + if (stageState.authorizedCommitters(partition) == taskId) { logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + s"partition=$partition) failed; clearing lock") - stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER + stageState.authorizedCommitters(partition) = null } } } @@ -155,47 +172,41 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Marked private[scheduler] instead of private so this can be mocked in tests private[scheduler] def handleAskPermissionToCommit( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber): Boolean = synchronized { + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int): Boolean = synchronized { stageStates.get(stage) match { - case Some(state) if attemptFailed(state, partition, attemptNumber) => - logInfo(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage," + - s" partition=$partition as task attempt $attemptNumber has already failed.") + case Some(state) if attemptFailed(state, stageAttempt, partition, attemptNumber) => + logInfo(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + s"task attempt $attemptNumber already marked as failed.") false case Some(state) => - state.authorizedCommitters(partition) match { - case NO_AUTHORIZED_COMMITTER => - logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + - s"partition=$partition") - state.authorizedCommitters(partition) = attemptNumber - true - case existingCommitter => - // Coordinator should be idempotent when receiving AskPermissionToCommit. - if (existingCommitter == attemptNumber) { - logWarning(s"Authorizing duplicate request to commit for " + - s"attemptNumber=$attemptNumber to commit for stage=$stage," + - s" partition=$partition; existingCommitter = $existingCommitter." + - s" This can indicate dropped network traffic.") - true - } else { - logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + - s"partition=$partition; existingCommitter = $existingCommitter") - false - } + val existing = state.authorizedCommitters(partition) + if (existing == null) { + logDebug(s"Commit allowed for stage=$stage.$stageAttempt, partition=$partition, " + + s"task attempt $attemptNumber") + state.authorizedCommitters(partition) = TaskIdentifier(stageAttempt, attemptNumber) + true + } else { + logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + s"already committed by $existing") + false } case None => - logDebug(s"Stage $stage has completed, so not allowing" + - s" attempt number $attemptNumber of partition $partition to commit") + logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + "stage already marked as completed.") false } } private def attemptFailed( stageState: StageState, - partition: PartitionId, - attempt: TaskAttemptNumber): Boolean = synchronized { - stageState.failures.get(partition).exists(_.contains(attempt)) + stageAttempt: Int, + partition: Int, + attempt: Int): Boolean = synchronized { + val failInfo = TaskIdentifier(stageAttempt, attempt) + stageState.failures.get(partition).exists(_.contains(failInfo)) } } @@ -215,9 +226,10 @@ private[spark] object OutputCommitCoordinator { } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case AskPermissionToCommitOutput(stage, partition, attemptNumber) => + case AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) => context.reply( - outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber)) + outputCommitCoordinator.handleAskPermissionToCommit(stage, stageAttempt, partition, + attemptNumber)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 0c11806b3981b..598b62f85a1fa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} * up to launch speculative tasks, etc. * * Clients should first call initialize() and start(), then submit task sets through the - * runTasks method. + * submitTasks method. * * THREADING: [[SchedulerBackend]]s and task-submitting clients can call this class from multiple * threads, so it needs locks in public API methods to maintain its state. In addition, some @@ -62,7 +62,7 @@ private[spark] class TaskSchedulerImpl( this(sc, sc.conf.get(config.MAX_TASK_FAILURES)) } - // Lazily initializing blackListTrackOpt to avoid getting empty ExecutorAllocationClient, + // Lazily initializing blacklistTrackerOpt to avoid getting empty ExecutorAllocationClient, // because ExecutorAllocationClient is created after this TaskSchedulerImpl. private[scheduler] lazy val blacklistTrackerOpt = maybeCreateBlacklistTracker(sc) @@ -228,7 +228,7 @@ private[spark] class TaskSchedulerImpl( // 1. The task set manager has been created and some tasks have been scheduled. // In this case, send a kill signal to the executors to kill the task and then abort // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // 2. The task set manager has been created but no tasks have been scheduled. In this case, // simply abort the stage. tsm.runningTasksSet.foreach { tid => taskIdToExecutorId.get(tid).foreach(execId => @@ -689,6 +689,20 @@ private[spark] class TaskSchedulerImpl( } } + /** + * Marks the task has completed in all TaskSetManagers for the given stage. + * + * After stage failure and retry, there may be multiple TaskSetManagers for the stage. + * If an earlier attempt of a stage completes a task, we should ensure that the later attempts + * do not also submit those same tasks. That also means that a task completion from an earlier + * attempt can lead to the entire stage getting marked as successful. + */ + private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = { + taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm => + tsm.markPartitionCompleted(partitionId) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d958658527f6d..a18c66596852a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -73,6 +73,8 @@ private[spark] class TaskSetManager( val ser = env.closureSerializer.newInstance() val tasks = taskSet.tasks + private[scheduler] val partitionToIndex = tasks.zipWithIndex + .map { case (t, idx) => t.partitionId -> idx }.toMap val numTasks = tasks.length val copiesRunning = new Array[Int](numTasks) @@ -153,7 +155,7 @@ private[spark] class TaskSetManager( private[scheduler] val speculatableTasks = new HashSet[Int] // Task index, start and finish time for each task attempt (indexed by task ID) - private val taskInfos = new HashMap[Long, TaskInfo] + private[scheduler] val taskInfos = new HashMap[Long, TaskInfo] // Use a MedianHeap to record durations of successful tasks so we know when to launch // speculative tasks. This is only used when speculation is enabled, to avoid the overhead @@ -287,7 +289,7 @@ private[spark] class TaskSetManager( None } - /** Check whether a task is currently running an attempt on a given host */ + /** Check whether a task once ran an attempt on a given host */ private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = { taskAttempts(taskIndex).exists(_.host == host) } @@ -754,6 +756,9 @@ private[spark] class TaskSetManager( logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } + // There may be multiple tasksets for this stage -- we let all of them know that the partition + // was completed. This may result in some of the tasksets getting completed. + sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not // "deserialize" the value when holding a lock to avoid blocking other threads. So we call @@ -764,6 +769,19 @@ private[spark] class TaskSetManager( maybeFinishTaskSet() } + private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = { + partitionToIndex.get(partitionId).foreach { index => + if (!successful(index)) { + tasksSuccessful += 1 + successful(index) = true + if (tasksSuccessful == numTasks) { + isZombie = true + } + maybeFinishTaskSet() + } + } + } + /** * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the * DAG Scheduler. @@ -833,13 +851,19 @@ private[spark] class TaskSetManager( } ef.exception + case tk: TaskKilled => + // TaskKilled might have accumulator updates + accumUpdates = tk.accums + logWarning(failureReason) + None + case e: ExecutorLostFailure if !e.exitCausedByApp => logInfo(s"Task $tid failed because while it was being computed, its executor " + "exited for a reason unrelated to the task. Not counting this failure towards the " + "maximum number of failures for the task.") None - case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others + case e: TaskFailedReason => // TaskResultLost and others logWarning(failureReason) None } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5627a557a12f3..9b90e309d2e04 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -170,8 +170,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (executorDataMap.contains(executorId)) { executorRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) context.reply(true) - } else if (scheduler.nodeBlacklist != null && - scheduler.nodeBlacklist.contains(hostname)) { + } else if (scheduler.nodeBlacklist.contains(hostname)) { // If the cluster manager gives us an executor on a blacklisted node (because it // already started allocating those resources before we informed it of our blacklist, // or if it ignored our blacklist), then we reject that executor immediately. @@ -633,7 +632,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } doRequestTotalExecutors(requestedTotalExecutors) } else { - numPendingExecutors += knownExecutors.size + numPendingExecutors += executorsToKill.size Future.successful(true) } diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala new file mode 100644 index 0000000000000..d15e7937b0523 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.security + +import java.io.{DataInputStream, DataOutputStream, InputStream} +import java.net.Socket +import java.nio.charset.StandardCharsets.UTF_8 + +import org.apache.spark.SparkConf +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils + +/** + * A class that can be used to add a simple authentication protocol to socket-based communication. + * + * The protocol is simple: an auth secret is written to the socket, and the other side checks the + * secret and writes either "ok" or "err" to the output. If authentication fails, the socket is + * not expected to be valid anymore. + * + * There's no secrecy, so this relies on the sockets being either local or somehow encrypted. + */ +private[spark] class SocketAuthHelper(conf: SparkConf) { + + val secret = Utils.createSecret(conf) + + /** + * Read the auth secret from the socket and compare to the expected value. Write the reply back + * to the socket. + * + * If authentication fails, this method will close the socket. + * + * @param s The client socket. + * @throws IllegalArgumentException If authentication fails. + */ + def authClient(s: Socket): Unit = { + // Set the socket timeout while checking the auth secret. Reset it before returning. + val currentTimeout = s.getSoTimeout() + try { + s.setSoTimeout(10000) + val clientSecret = readUtf8(s) + if (secret == clientSecret) { + writeUtf8("ok", s) + } else { + writeUtf8("err", s) + JavaUtils.closeQuietly(s) + } + } finally { + s.setSoTimeout(currentTimeout) + } + } + + /** + * Authenticate with a server by writing the auth secret and checking the server's reply. + * + * If authentication fails, this method will close the socket. + * + * @param s The socket connected to the server. + * @throws IllegalArgumentException If authentication fails. + */ + def authToServer(s: Socket): Unit = { + writeUtf8(secret, s) + + val reply = readUtf8(s) + if (reply != "ok") { + JavaUtils.closeQuietly(s) + throw new IllegalArgumentException("Authentication failed.") + } + } + + protected def readUtf8(s: Socket): String = { + val din = new DataInputStream(s.getInputStream()) + val len = din.readInt() + val bytes = new Array[Byte](len) + din.readFully(bytes) + new String(bytes, UTF_8) + } + + protected def writeUtf8(str: String, s: Socket): Unit = { + val bytes = str.getBytes(UTF_8) + val dout = new DataOutputStream(s.getOutputStream()) + dout.writeInt(bytes.length) + dout.write(bytes, 0, bytes.length) + dout.flush() + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 688f25a9fdea1..e237281c552b1 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -471,7 +471,7 @@ private[spark] class AppStatusStore( def operationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = { val job = store.read(classOf[JobDataWrapper], jobId) - val stages = job.info.stageIds + val stages = job.info.stageIds.sorted stages.map { id => val g = store.read(classOf[RDDOperationGraphWrapper], id).toRDDOperationGraph() diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 7127397f6205c..84c2ad48f1f27 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -28,7 +28,7 @@ import org.glassfish.jersey.server.ServerProperties import org.glassfish.jersey.servlet.ServletContainer import org.apache.spark.SecurityManager -import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.{SparkUI, UIUtils} /** * Main entry point for serving spark application metrics as json, using JAX-RS. @@ -49,6 +49,7 @@ private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications/{appId}") def application(): Class[OneApplicationResource] = classOf[OneApplicationResource] + @GET @Path("version") def version(): VersionInfo = new VersionInfo(org.apache.spark.SPARK_VERSION) @@ -147,38 +148,18 @@ private[v1] trait BaseAppResource extends ApiRequestContext { } private[v1] class ForbiddenException(msg: String) extends WebApplicationException( - Response.status(Response.Status.FORBIDDEN).entity(msg).build()) + UIUtils.buildErrorResponse(Response.Status.FORBIDDEN, msg)) private[v1] class NotFoundException(msg: String) extends WebApplicationException( - new NoSuchElementException(msg), - Response - .status(Response.Status.NOT_FOUND) - .entity(ErrorWrapper(msg)) - .build() -) + UIUtils.buildErrorResponse(Response.Status.NOT_FOUND, msg)) private[v1] class ServiceUnavailable(msg: String) extends WebApplicationException( - new ServiceUnavailableException(msg), - Response - .status(Response.Status.SERVICE_UNAVAILABLE) - .entity(ErrorWrapper(msg)) - .build() -) + UIUtils.buildErrorResponse(Response.Status.SERVICE_UNAVAILABLE, msg)) private[v1] class BadParameterException(msg: String) extends WebApplicationException( - new IllegalArgumentException(msg), - Response - .status(Response.Status.BAD_REQUEST) - .entity(ErrorWrapper(msg)) - .build() -) { + UIUtils.buildErrorResponse(Response.Status.BAD_REQUEST, msg)) { def this(param: String, exp: String, actual: String) = { this(raw"""Bad value for parameter "$param". Expected a $exp, got "$actual"""") } } -/** - * Signal to JacksonMessageWriter to not convert the message into json (which would result in an - * extra set of quotes). - */ -private[v1] case class ErrorWrapper(s: String) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala index 76af33c1a18db..4560d300cb0c8 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -68,10 +68,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ mediaType: MediaType, multivaluedMap: MultivaluedMap[String, AnyRef], outputStream: OutputStream): Unit = { - t match { - case ErrorWrapper(err) => outputStream.write(err.getBytes(StandardCharsets.UTF_8)) - case _ => mapper.writeValue(outputStream, t) - } + mapper.writeValue(outputStream, t) } override def getSize( diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index 974697890dd03..32100c5704538 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -140,11 +140,8 @@ private[v1] class AbstractApplicationResource extends BaseAppResource { .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM) .build() } catch { - case NonFatal(e) => - Response.serverError() - .entity(s"Event logs are not available for app: $appId.") - .status(Response.Status.SERVICE_UNAVAILABLE) - .build() + case NonFatal(_) => + throw new ServiceUnavailable(s"Event logs are not available for app: $appId.") } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e0276a4dc4224..0e1c7d5fd3fa2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -45,6 +45,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager} import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv +import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.memory._ @@ -291,7 +292,7 @@ private[spark] class BlockManager( case e: Exception if i < MAX_ATTEMPTS => logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) - Thread.sleep(SLEEP_TIME_SECS * 1000) + Thread.sleep(SLEEP_TIME_SECS * 1000L) case NonFatal(e) => throw new SparkException("Unable to register with external shuffle server due to : " + e.getMessage, e) @@ -1554,7 +1555,7 @@ private[spark] class BlockManager( private[spark] object BlockManager { private val ID_GENERATOR = new IdGenerator - def blockIdsToHosts( + def blockIdsToLocations( blockIds: Array[BlockId], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = { @@ -1569,7 +1570,9 @@ private[spark] object BlockManager { val blockManagers = new HashMap[BlockId, Seq[String]] for (i <- 0 until blockIds.length) { - blockManagers(blockIds(i)) = blockLocations(i).map(_.host) + blockManagers(blockIds(i)) = blockLocations(i).map { loc => + ExecutorCacheTaskLocation(loc.host, loc.executorId).toString + } } blockManagers.toMap } diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index e5abbf745cc41..9ccc8f9cc585b 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -17,7 +17,9 @@ package org.apache.spark.storage +import org.apache.spark.SparkEnv import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.config._ import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.util.Utils @@ -53,10 +55,16 @@ class RDDInfo( } private[spark] object RDDInfo { + private val callsiteForm = SparkEnv.get.conf.get(EVENT_LOG_CALLSITE_FORM) + def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) + val callSite = callsiteForm match { + case "short" => rdd.creationSite.shortForm + case "long" => rdd.creationSite.longForm + } new RDDInfo(rdd.id, rddName, rdd.partitions.length, - rdd.getStorageLevel, parentIds, rdd.creationSite.shortForm, rdd.scope) + rdd.getStorageLevel, parentIds, callSite, rdd.scope) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index dd9df74689a13..b31862323a895 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -48,7 +48,9 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * @param blockManager [[BlockManager]] for reading local blocks * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in - * order to throttle the memory usage. + * order to throttle the memory usage. Note that zero-sized blocks are + * already excluded, which happened in + * [[MapOutputTracker.convertMapStatuses]]. * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. @@ -62,7 +64,7 @@ final class ShuffleBlockFetcherIterator( context: TaskContext, shuffleClient: ShuffleClient, blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])], streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, @@ -74,8 +76,8 @@ final class ShuffleBlockFetcherIterator( import ShuffleBlockFetcherIterator._ /** - * Total number of blocks to fetch. This can be smaller than the total number of blocks - * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]]. + * Total number of blocks to fetch. This should be equal to the total number of blocks + * in [[blocksByAddress]] because we already filter out zero-sized blocks in [[blocksByAddress]]. * * This should equal localBlocks.size + remoteBlocks.size. */ @@ -267,13 +269,16 @@ final class ShuffleBlockFetcherIterator( // at most maxBytesInFlight in order to limit the amount of data in flight. val remoteRequests = new ArrayBuffer[FetchRequest] - // Tracks total number of blocks (including zero sized blocks) - var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { - totalBlocks += blockInfos.size if (address.executorId == blockManager.blockManagerId.executorId) { - // Filter out zero-sized blocks - localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) + blockInfos.find(_._2 <= 0) match { + case Some((blockId, size)) if size < 0 => + throw new BlockException(blockId, "Negative block size " + size) + case Some((blockId, size)) if size == 0 => + throw new BlockException(blockId, "Zero-sized blocks should be excluded.") + case None => // do nothing. + } + localBlocks ++= blockInfos.map(_._1) numBlocksToFetch += localBlocks.size } else { val iterator = blockInfos.iterator @@ -281,14 +286,15 @@ final class ShuffleBlockFetcherIterator( var curBlocks = new ArrayBuffer[(BlockId, Long)] while (iterator.hasNext) { val (blockId, size) = iterator.next() - // Skip empty blocks - if (size > 0) { + if (size < 0) { + throw new BlockException(blockId, "Negative block size " + size) + } else if (size == 0) { + throw new BlockException(blockId, "Zero-sized blocks should be excluded.") + } else { curBlocks += ((blockId, size)) remoteBlocks += blockId numBlocksToFetch += 1 curRequestSize += size - } else if (size < 0) { - throw new BlockException(blockId, "Negative block size " + size) } if (curRequestSize >= targetRequestSize || curBlocks.size >= maxBlocksInFlightPerAddress) { @@ -306,7 +312,8 @@ final class ShuffleBlockFetcherIterator( } } } - logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks") + logInfo(s"Getting $numBlocksToFetch non-empty blocks including ${localBlocks.size}" + + s" local blocks and ${remoteBlocks.size} remote blocks") remoteRequests } @@ -407,6 +414,25 @@ final class ShuffleBlockFetcherIterator( logDebug("Number of requests in flight " + reqsInFlight) } + if (buf.size == 0) { + // We will never legitimately receive a zero-size block. All blocks with zero records + // have zero size and all zero-size blocks have no records (and hence should never + // have been requested in the first place). This statement relies on behaviors of the + // shuffle writers, which are guaranteed by the following test cases: + // + // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions" + // - UnsafeShuffleWriterSuite: "writeEmptyIterator" + // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing" + // + // There is not an explicit test for SortShuffleWriter but the underlying APIs that + // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter + // which returns a zero-size from commitAndGet() in case no records were written + // since the last call. + val msg = s"Received a zero-size buffer for block $blockId from $address " + + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" + throwFetchFailedException(blockId, address, new IOException(msg)) + } + val in = try { buf.createInputStream() } catch { diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 0e8a6307de6a8..52a955111231a 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -263,7 +263,7 @@ private[spark] object JettyUtils extends Logging { filters.foreach { case filter : String => if (!filter.isEmpty) { - logInfo("Adding filter: " + filter) + logInfo(s"Adding filter $filter to ${handlers.map(_.getContextPath).mkString(", ")}.") val holder : FilterHolder = new FilterHolder() holder.setClassName(filter) // Get any parameters for each filter @@ -344,6 +344,7 @@ private[spark] object JettyUtils extends Logging { connectionFactories: _*) connector.setPort(port) connector.setHost(hostName) + connector.setReuseAddress(!Utils.isWindows) // Currently we only use "SelectChannelConnector" // Limit the max acceptor number to 8 so that we don't waste a lot of threads @@ -406,7 +407,7 @@ private[spark] object JettyUtils extends Logging { } pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) - ServerInfo(server, httpPort, securePort, collection) + ServerInfo(server, httpPort, securePort, conf, collection) } catch { case e: Exception => server.stop() @@ -506,10 +507,12 @@ private[spark] case class ServerInfo( server: Server, boundPort: Int, securePort: Option[Int], + conf: SparkConf, private val rootHandler: ContextHandlerCollection) { - def addHandler(handler: ContextHandler): Unit = { + def addHandler(handler: ServletContextHandler): Unit = { handler.setVirtualHosts(JettyUtils.toVirtualHosts(JettyUtils.SPARK_CONNECTOR_NAME)) + JettyUtils.addFilters(Seq(handler), conf) rootHandler.addHandler(handler) if (!handler.isStarted()) { handler.start() diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index b44ac0ea1febc..d315ef66e0dc0 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -65,7 +65,7 @@ private[spark] class SparkUI private ( attachTab(new StorageTab(this, store)) attachTab(new EnvironmentTab(this, store)) attachTab(new ExecutorsTab(this)) - attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(SparkUI.STATIC_RESOURCE_DIR) attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 02cf19e00ecde..732b7528f499e 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -20,6 +20,8 @@ package org.apache.spark.ui import java.net.URLDecoder import java.text.SimpleDateFormat import java.util.{Date, Locale, TimeZone} +import javax.servlet.http.HttpServletRequest +import javax.ws.rs.core.{MediaType, Response} import scala.util.control.NonFatal import scala.xml._ @@ -148,60 +150,71 @@ private[spark] object UIUtils extends Logging { } // Yarn has to go through a proxy so the base uri is provided and has to be on all links - def uiRoot: String = { + def uiRoot(request: HttpServletRequest): String = { + // Knox uses X-Forwarded-Context to notify the application the base path + val knoxBasePath = Option(request.getHeader("X-Forwarded-Context")) // SPARK-11484 - Use the proxyBase set by the AM, if not found then use env. sys.props.get("spark.ui.proxyBase") .orElse(sys.env.get("APPLICATION_WEB_PROXY_BASE")) + .orElse(knoxBasePath) .getOrElse("") } - def prependBaseUri(basePath: String = "", resource: String = ""): String = { - uiRoot + basePath + resource + def prependBaseUri( + request: HttpServletRequest, + basePath: String = "", + resource: String = ""): String = { + uiRoot(request) + basePath + resource } - def commonHeaderNodes: Seq[Node] = { + def commonHeaderNodes(request: HttpServletRequest): Seq[Node] = { - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + } - def vizHeaderNodes: Seq[Node] = { - - - - - + def vizHeaderNodes(request: HttpServletRequest): Seq[Node] = { + + + + + } - def dataTablesHeaderNodes: Seq[Node] = { + def dataTablesHeaderNodes(request: HttpServletRequest): Seq[Node] = { + + href={prependBaseUri(request, "/static/dataTables.bootstrap.css")} type="text/css"/> - - - - - - - + href={prependBaseUri(request, "/static/jsonFormatter.min.css")} type="text/css"/> + + + + + + } /** Returns a spark page with correctly formatted headers */ def headerSparkPage( + request: HttpServletRequest, title: String, content: => Seq[Node], activeTab: SparkUITab, @@ -214,25 +227,26 @@ private[spark] object UIUtils extends Logging { val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." val header = activeTab.headerTabs.map { tab =>
  • - {tab.name} + {tab.name}
  • } val helpButton: Seq[Node] = helpText.map(tooltip(_, "bottom")).getOrElse(Seq.empty) - {commonHeaderNodes} - {if (showVisualization) vizHeaderNodes else Seq.empty} - {if (useDataTables) dataTablesHeaderNodes else Seq.empty} - + {commonHeaderNodes(request)} + {if (showVisualization) vizHeaderNodes(request) else Seq.empty} + {if (useDataTables) dataTablesHeaderNodes(request) else Seq.empty} + {appName} - {title} } - UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent, showVisualization = true) + UIUtils.headerSparkPage( + request, s"Details for Job $jobId", content, parent, showVisualization = true) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index a3e1f13782e30..22a40101e33df 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -49,7 +49,7 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { "stages/pool", parent.isFairScheduler, parent.killEnabled, false) val poolTable = new PoolTable(Map(pool -> uiPool), parent) - var content =

    Summary

    ++ poolTable.toNodeSeq + var content =

    Summary

    ++ poolTable.toNodeSeq(request) if (activeStages.nonEmpty) { content ++= } - UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) + UIUtils.headerSparkPage(request, "Fair Scheduler Pool: " + poolName, content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index 5dfce858dec07..96b5f72393070 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -18,6 +18,7 @@ package org.apache.spark.ui.jobs import java.net.URLEncoder +import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -28,7 +29,7 @@ import org.apache.spark.ui.UIUtils /** Table showing list of pools */ private[ui] class PoolTable(pools: Map[Schedulable, PoolData], parent: StagesTab) { - def toNodeSeq: Seq[Node] = { + def toNodeSeq(request: HttpServletRequest): Seq[Node] = { @@ -39,15 +40,15 @@ private[ui] class PoolTable(pools: Map[Schedulable, PoolData], parent: StagesTab - {pools.map { case (s, p) => poolRow(s, p) }} + {pools.map { case (s, p) => poolRow(request, s, p) }}
    Pool NameSchedulingMode
    } - private def poolRow(s: Schedulable, p: PoolData): Seq[Node] = { + private def poolRow(request: HttpServletRequest, s: Schedulable, p: PoolData): Seq[Node] = { val activeStages = p.stageIds.size val href = "%s/stages/pool?poolname=%s" - .format(UIUtils.prependBaseUri(parent.basePath), URLEncoder.encode(p.name, "UTF-8")) + .format(UIUtils.prependBaseUri(request, parent.basePath), URLEncoder.encode(p.name, "UTF-8")) {p.name} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index ac83de10f9237..55eb989962668 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -112,20 +112,19 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We

    No information to display for Stage {stageId} (Attempt {stageAttemptId})

    - return UIUtils.headerSparkPage(stageHeader, content, parent) + return UIUtils.headerSparkPage(request, stageHeader, content, parent) } val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) - val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + - stageData.numFailedTasks + stageData.numKilledTasks + val totalTasks = taskCount(stageData) if (totalTasks == 0) { val content =

    Summary Metrics

    No tasks have started yet

    Tasks

    No tasks have started yet
    - return UIUtils.headerSparkPage(stageHeader, content, parent) + return UIUtils.headerSparkPage(request, stageHeader, content, parent) } val storedTasks = store.taskCount(stageData.stageId, stageData.attemptId) @@ -133,7 +132,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val totalTasksNumStr = if (totalTasks == storedTasks) { s"$totalTasks" } else { - s"$totalTasks, showing ${storedTasks}" + s"$storedTasks, showing ${totalTasks}" } val summary = @@ -282,8 +281,8 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( stageData, - UIUtils.prependBaseUri(parent.basePath) + - s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", + UIUtils.prependBaseUri(request, parent.basePath) + + s"/stages/stage/?id=${stageId}&attempt=${stageAttemptId}", currentTime, pageSize = taskPageSize, sortColumn = taskSortColumn, @@ -498,7 +497,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
    {taskTableHTML ++ jsForScrollingDownToTaskTable}
    - UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) + UIUtils.headerSparkPage(request, stageHeader, content, parent, showVisualization = true) } def makeTimeline(tasks: Seq[TaskData], currentTime: Long): Seq[Node] = { @@ -686,7 +685,7 @@ private[ui] class TaskDataSource( private var _tasksToShow: Seq[TaskData] = null - override def dataSize: Int = stage.numTasks + override def dataSize: Int = taskCount(stage) override def sliceData(from: Int, to: Int): Seq[TaskData] = { if (_tasksToShow == null) { @@ -1052,4 +1051,9 @@ private[ui] object ApiHelper { (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) } + def taskCount(stageData: StageData): Int = { + stageData.numActiveTasks + stageData.numCompleteTasks + stageData.numFailedTasks + + stageData.numKilledTasks + } + } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 18a4926f2f6c0..d01acdae59c9f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -43,7 +43,9 @@ private[ui] class StageTableBase( killEnabled: Boolean, isFailedStage: Boolean) { // stripXSS is called to remove suspicious characters used in XSS attacks - val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) + val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) => + UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq + } val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag)) .map(para => para._1 + "=" + para._2(0)) @@ -92,7 +94,8 @@ private[ui] class StageTableBase( stageSortColumn, stageSortDesc, isFailedStage, - parameterOtherTable + parameterOtherTable, + request ).table(page) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => @@ -147,7 +150,8 @@ private[ui] class StagePagedTable( sortColumn: String, desc: Boolean, isFailedStage: Boolean, - parameterOtherTable: Iterable[String]) extends PagedTable[StageTableRowData] { + parameterOtherTable: Iterable[String], + request: HttpServletRequest) extends PagedTable[StageTableRowData] { override def tableId: String = stageTag + "-table" @@ -161,7 +165,7 @@ private[ui] class StagePagedTable( override def pageNumberFormField: String = stageTag + ".page" - val parameterPath = UIUtils.prependBaseUri(basePath) + s"/$subPath/?" + + val parameterPath = UIUtils.prependBaseUri(request, basePath) + s"/$subPath/?" + parameterOtherTable.mkString("&") override val dataSource = new StageDataSource( @@ -288,7 +292,7 @@ private[ui] class StagePagedTable( {if (isFairScheduler) { + .format(UIUtils.prependBaseUri(request, basePath), data.schedulingPool)}> {data.schedulingPool} @@ -346,7 +350,7 @@ private[ui] class StagePagedTable( } private def makeDescription(s: v1.StageData, descriptionOption: Option[String]): Seq[Node] = { - val basePathUri = UIUtils.prependBaseUri(basePath) + val basePathUri = UIUtils.prependBaseUri(request, basePath) val killLink = if (killEnabled) { val confirm = @@ -366,7 +370,7 @@ private[ui] class StagePagedTable( Seq.empty } - val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" + val nameLinkUri = s"$basePathUri/stages/stage/?id=${s.stageId}&attempt=${s.attemptId}" val nameLink = {s.name} val cachedRddInfos = store.rddList().filter { rdd => s.rddIds.contains(rdd.id) } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 2674b9291203a..238cd31433660 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -53,7 +53,7 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web } catch { case _: NoSuchElementException => // Rather than crashing, render an "RDD Not Found" page - return UIUtils.headerSparkPage("RDD Not Found", Seq.empty[Node], parent) + return UIUtils.headerSparkPage(request, "RDD Not Found", Seq.empty[Node], parent) } // Worker table @@ -72,7 +72,7 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web } val blockTableHTML = try { val _blockTable = new BlockPagedTable( - UIUtils.prependBaseUri(parent.basePath) + s"/storage/rdd/?id=${rddId}", + UIUtils.prependBaseUri(request, parent.basePath) + s"/storage/rdd/?id=${rddId}", rddStorageInfo.partitions.get, blockPageSize, blockSortColumn, @@ -145,7 +145,8 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web {blockTableHTML ++ jsForScrollingDownToBlockTable} ; - UIUtils.headerSparkPage("RDD Storage Info for " + rddStorageInfo.name, content, parent) + UIUtils.headerSparkPage( + request, "RDD Storage Info for " + rddStorageInfo.name, content, parent) } /** Header fields for the worker table */ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 68d946574a37b..3eb546e336e99 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -31,11 +31,14 @@ import org.apache.spark.util.Utils private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { - val content = rddTable(store.rddList()) ++ receiverBlockTables(store.streamBlocksList()) - UIUtils.headerSparkPage("Storage", content, parent) + val content = rddTable(request, store.rddList()) ++ + receiverBlockTables(store.streamBlocksList()) + UIUtils.headerSparkPage(request, "Storage", content, parent) } - private[storage] def rddTable(rdds: Seq[v1.RDDStorageInfo]): Seq[Node] = { + private[storage] def rddTable( + request: HttpServletRequest, + rdds: Seq[v1.RDDStorageInfo]): Seq[Node] = { if (rdds.isEmpty) { // Don't show the rdd table if there is no RDD persisted. Nil @@ -49,7 +52,11 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends
    - {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} + {UIUtils.listingTable( + rddHeader, + rddRow(request, _: v1.RDDStorageInfo), + rdds, + id = Some("storage-by-rdd-table"))}
    } @@ -66,12 +73,13 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends "Size on Disk") /** Render an HTML row representing an RDD */ - private def rddRow(rdd: v1.RDDStorageInfo): Seq[Node] = { + private def rddRow(request: HttpServletRequest, rdd: v1.RDDStorageInfo): Seq[Node] = { // scalastyle:off {rdd.id} - + {rdd.name} diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 0f84ea9752cf5..bf618b4afbce0 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext} +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.AccumulableInfo private[spark] case class AccumulatorMetadata( @@ -199,10 +200,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } override def toString: String = { + // getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead if (metadata == null) { - "Un-registered Accumulator: " + getClass.getSimpleName + "Un-registered Accumulator: " + Utils.getSimpleName(getClass) } else { - getClass.getSimpleName + s"(id: $id, name: $name, value: $value)" + Utils.getSimpleName(getClass) + s"(id: $id, name: $name, value: $value)" } } } @@ -211,7 +214,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { /** * An internal class used to track accumulators by Spark itself. */ -private[spark] object AccumulatorContext { +private[spark] object AccumulatorContext extends Logging { /** * This global map holds the original accumulator objects that are created on the driver. @@ -258,13 +261,16 @@ private[spark] object AccumulatorContext { * Returns the [[AccumulatorV2]] registered with the given ID, if any. */ def get(id: Long): Option[AccumulatorV2[_, _]] = { - Option(originals.get(id)).map { ref => - // Since we are storing weak references, we must check whether the underlying data is valid. + val ref = originals.get(id) + if (ref eq null) { + None + } else { + // Since we are storing weak references, warn when the underlying data is not valid. val acc = ref.get if (acc eq null) { - throw new IllegalStateException(s"Attempted to access garbage collected accumulator $id") + logWarning(s"Attempted to access garbage collected accumulator $id") } - acc + Option(acc) } } @@ -486,7 +492,9 @@ class LegacyAccumulatorWrapper[R, T]( param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] { private[spark] var _value = initialValue // Current value on driver - override def isZero: Boolean = _value == param.zero(initialValue) + @transient private lazy val _zero = param.zero(initialValue) + + override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef]) override def copy(): LegacyAccumulatorWrapper[R, T] = { val acc = new LegacyAccumulatorWrapper(initialValue, param) @@ -495,7 +503,7 @@ class LegacyAccumulatorWrapper[R, T]( } override def reset(): Unit = { - _value = param.zero(initialValue) + _value = _zero } override def add(v: T): Unit = _value = param.addAccumulator(_value, v) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index ad0c0639521f6..073d71c63b0c7 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.{Map, Set, Stack} import scala.language.existentials -import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type} -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.internal.Logging diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index 3ea9139e11027..651ea4996f6cb 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -37,7 +37,8 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging { private val stopped = new AtomicBoolean(false) - private val eventThread = new Thread(name) { + // Exposed for testing. + private[spark] val eventThread = new Thread(name) { setDaemon(true) override def run(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 40383fe05026b..50c6461373dee 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -407,7 +407,9 @@ private[spark] object JsonProtocol { ("Exit Caused By App" -> exitCausedByApp) ~ ("Loss Reason" -> reason.map(_.toString)) case taskKilled: TaskKilled => - ("Kill Reason" -> taskKilled.reason) + val accumUpdates = JArray(taskKilled.accumUpdates.map(accumulableInfoToJson).toList) + ("Kill Reason" -> taskKilled.reason) ~ + ("Accumulator Updates" -> accumUpdates) case _ => emptyJson } ("Reason" -> reason) ~ json @@ -917,7 +919,10 @@ private[spark] object JsonProtocol { case `taskKilled` => val killReason = jsonOption(json \ "Kill Reason") .map(_.extract[String]).getOrElse("unknown reason") - TaskKilled(killReason) + val accumUpdates = jsonOption(json \ "Accumulator Updates") + .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) + .getOrElse(Seq[AccumulableInfo]()) + TaskKilled(killReason, accumUpdates) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. To provide backward compatibility diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index b25a731401f23..d4474a90b26f1 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -60,6 +60,15 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } } + /** + * This can be overriden by subclasses if there is any extra cleanup to do when removing a + * listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus. + */ + def removeListenerOnError(listener: L): Unit = { + removeListener(listener) + } + + /** * Post the event to all registered listeners. The `postToAll` caller should guarantee calling * `postToAll` in the same thread for all events. @@ -80,7 +89,16 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } try { doPostEvent(listener, event) + if (Thread.interrupted()) { + // We want to throw the InterruptedException right away so we can associate the interrupt + // with this listener, as opposed to waiting for a queue.take() etc. to detect it. + throw new InterruptedException() + } } catch { + case ie: InterruptedException => + logError(s"Interrupted while posting to ${Utils.getFormattedClassName(listener)}. " + + s"Removing that listener.", ie) + removeListenerOnError(listener) case NonFatal(e) if !isIgnorableException(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) } finally { diff --git a/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala b/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala new file mode 100644 index 0000000000000..1aa2009fa9b5b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +/** + * SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we catch + * fatal throwable in {@link scala.concurrent.Future}'s body, and re-throw + * SparkFatalException, which wraps the fatal throwable inside. + * Note that SparkFatalException should only be thrown from a {@link scala.concurrent.Future}, + * which is run by using ThreadUtils.awaitResult. ThreadUtils.awaitResult will catch + * it and re-throw the original exception/error. + */ +private[spark] final class SparkFatalException(val throwable: Throwable) extends Exception diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index e0f5af5250e7f..1b34fbde38cd6 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -39,10 +39,15 @@ private[spark] class SparkUncaughtExceptionHandler(val exitOnUncaughtException: // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) if (!ShutdownHookManager.inShutdown()) { - if (exception.isInstanceOf[OutOfMemoryError]) { - System.exit(SparkExitCode.OOM) - } else if (exitOnUncaughtException) { - System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) + exception match { + case _: OutOfMemoryError => + System.exit(SparkExitCode.OOM) + case e: SparkFatalException if e.throwable.isInstanceOf[OutOfMemoryError] => + // SPARK-24294: This is defensive code, in case that SparkFatalException is + // misused and uncaught. + System.exit(SparkExitCode.OOM) + case _ if exitOnUncaughtException => + System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 81aaf79db0c13..0f08a2b0ad895 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -19,13 +19,12 @@ package org.apache.spark.util import java.util.concurrent._ +import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor} -import scala.concurrent.duration.Duration +import scala.concurrent.duration.{Duration, FiniteDuration} import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} import scala.util.control.NonFatal -import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} - import org.apache.spark.SparkException private[spark] object ThreadUtils { @@ -103,6 +102,22 @@ private[spark] object ThreadUtils { executor } + /** + * Wrapper over ScheduledThreadPoolExecutor. + */ + def newDaemonThreadPoolScheduledExecutor(threadNamePrefix: String, numThreads: Int) + : ScheduledExecutorService = { + val threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(s"$threadNamePrefix-%d") + .build() + val executor = new ScheduledThreadPoolExecutor(numThreads, threadFactory) + // By default, a cancelled task is not automatically removed from the work queue until its delay + // elapses. We have to enable it manually. + executor.setRemoveOnCancelPolicy(true) + executor + } + /** * Run a piece of code in a new thread and return the result. Exception in the new thread is * thrown in the caller thread with an adjusted stack trace that removes references to this @@ -200,6 +215,8 @@ private[spark] object ThreadUtils { val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] awaitable.result(atMost)(awaitPermission) } catch { + case e: SparkFatalException => + throw e.throwable // TimeoutException is thrown in the current thread, so not need to warp the exception. case NonFatal(t) if !t.isInstanceOf[TimeoutException] => throw new SparkException("Exception thrown in awaitResult: ", t) @@ -227,4 +244,14 @@ private[spark] object ThreadUtils { } } // scalastyle:on awaitready + + def shutdown( + executor: ExecutorService, + gracePeriod: Duration = FiniteDuration(30, TimeUnit.SECONDS)): Unit = { + executor.shutdown() + executor.awaitTermination(gracePeriod.toMillis, TimeUnit.MILLISECONDS) + if (!executor.isShutdown) { + executor.shutdownNow() + } + } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d2be93226e2a2..a6fd3637663e8 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,6 +18,8 @@ package org.apache.spark.util import java.io._ +import java.lang.{Byte => JByte} +import java.lang.InternalError import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} @@ -26,11 +28,12 @@ import java.nio.ByteBuffer import java.nio.channels.{Channels, FileChannel} import java.nio.charset.StandardCharsets import java.nio.file.Files +import java.security.SecureRandom import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ +import java.util.concurrent.TimeUnit.NANOSECONDS import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.GZIPInputStream -import javax.net.ssl.HttpsURLConnection import scala.annotation.tailrec import scala.collection.JavaConverters._ @@ -44,6 +47,7 @@ import scala.util.matching.Regex import _root_.io.netty.channel.unix.Errors.NativeIoException import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} +import com.google.common.hash.HashCodes import com.google.common.io.{ByteStreams, Files => GFiles} import com.google.common.net.InetAddresses import org.apache.commons.lang3.SystemUtils @@ -96,7 +100,7 @@ private[spark] object Utils extends Logging { */ val DEFAULT_MAX_TO_STRING_FIELDS = 25 - private def maxNumToStringFields = { + private[spark] def maxNumToStringFields = { if (SparkEnv.get != null) { SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS) } else { @@ -431,7 +435,7 @@ private[spark] object Utils extends Logging { new URI("file:///" + rawFileName).getPath.substring(1) } - /** + /** * Download a file or directory to target directory. Supports fetching the file in a variety of * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based * on the URL parameter. Fetching directories is only supported from Hadoop-compatible @@ -504,6 +508,14 @@ private[spark] object Utils extends Logging { targetFile } + /** Records the duration of running `body`. */ + def timeTakenMs[T](body: => T): (T, Long) = { + val startTime = System.nanoTime() + val result = body + val endTime = System.nanoTime() + (result, math.max(NANOSECONDS.toMillis(endTime - startTime), 0)) + } + /** * Download `in` to `tempFile`, then move it to `destFile`. * @@ -808,15 +820,15 @@ private[spark] object Utils extends Logging { conf.getenv("SPARK_EXECUTOR_DIRS").split(File.pathSeparator) } else if (conf.getenv("SPARK_LOCAL_DIRS") != null) { conf.getenv("SPARK_LOCAL_DIRS").split(",") - } else if (conf.getenv("MESOS_DIRECTORY") != null && !shuffleServiceEnabled) { + } else if (conf.getenv("MESOS_SANDBOX") != null && !shuffleServiceEnabled) { // Mesos already creates a directory per Mesos task. Spark should use that directory // instead so all temporary files are automatically cleaned up when the Mesos task ends. // Note that we don't want this if the shuffle service is enabled because we want to // continue to serve shuffle files after the executors that wrote them have already exited. - Array(conf.getenv("MESOS_DIRECTORY")) + Array(conf.getenv("MESOS_SANDBOX")) } else { - if (conf.getenv("MESOS_DIRECTORY") != null && shuffleServiceEnabled) { - logInfo("MESOS_DIRECTORY available but not using provided Mesos sandbox because " + + if (conf.getenv("MESOS_SANDBOX") != null && shuffleServiceEnabled) { + logInfo("MESOS_SANDBOX available but not using provided Mesos sandbox because " + "spark.shuffle.service.enabled is enabled.") } // In non-Yarn mode (or for the driver in yarn-client mode), we cannot trust the user @@ -1818,7 +1830,7 @@ private[spark] object Utils extends Logging { /** Return the class name of the given object, removing all dollar signs */ def getFormattedClassName(obj: AnyRef): String = { - obj.getClass.getSimpleName.replace("$", "") + getSimpleName(obj.getClass).replace("$", "") } /** @@ -2689,6 +2701,86 @@ private[spark] object Utils extends Logging { s"k8s://$resolvedURL" } + + /** + * Replaces all the {{EXECUTOR_ID}} occurrences with the Executor Id + * and {{APP_ID}} occurrences with the App Id. + */ + def substituteAppNExecIds(opt: String, appId: String, execId: String): String = { + opt.replace("{{APP_ID}}", appId).replace("{{EXECUTOR_ID}}", execId) + } + + /** + * Replaces all the {{APP_ID}} occurrences with the App Id. + */ + def substituteAppId(opt: String, appId: String): String = { + opt.replace("{{APP_ID}}", appId) + } + + def createSecret(conf: SparkConf): String = { + val bits = conf.get(AUTH_SECRET_BIT_LENGTH) + val rnd = new SecureRandom() + val secretBytes = new Array[Byte](bits / JByte.SIZE) + rnd.nextBytes(secretBytes) + HashCodes.fromBytes(secretBytes).toString() + } + + /** + * Safer than Class obj's getSimpleName which may throw Malformed class name error in scala. + * This method mimicks scalatest's getSimpleNameOfAnObjectsClass. + */ + def getSimpleName(cls: Class[_]): String = { + try { + return cls.getSimpleName + } catch { + case err: InternalError => return stripDollars(stripPackages(cls.getName)) + } + } + + /** + * Remove the packages from full qualified class name + */ + private def stripPackages(fullyQualifiedName: String): String = { + fullyQualifiedName.split("\\.").takeRight(1)(0) + } + + /** + * Remove trailing dollar signs from qualified class name, + * and return the trailing part after the last dollar sign in the middle + */ + private def stripDollars(s: String): String = { + val lastDollarIndex = s.lastIndexOf('$') + if (lastDollarIndex < s.length - 1) { + // The last char is not a dollar sign + if (lastDollarIndex == -1 || !s.contains("$iw")) { + // The name does not have dollar sign or is not an intepreter + // generated class, so we should return the full string + s + } else { + // The class name is intepreter generated, + // return the part after the last dollar sign + // This is the same behavior as getClass.getSimpleName + s.substring(lastDollarIndex + 1) + } + } + else { + // The last char is a dollar sign + // Find last non-dollar char + val lastNonDollarChar = s.reverse.find(_ != '$') + lastNonDollarChar match { + case None => s + case Some(c) => + val lastNonDollarIndex = s.lastIndexOf(c) + if (lastNonDollarIndex == -1) { + s + } else { + // Strip the trailing dollar signs + // Invoke stripDollars again to get the simple name + stripDollars(s.substring(0, lastNonDollarIndex + 1)) + } + } + } + } } private[util] object CallerContext extends Logging { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 176f84fa2a0d2..b159200d79222 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -368,8 +368,8 @@ private[spark] class ExternalSorter[K, V, C]( val bufferedIters = iterators.filter(_.hasNext).map(_.buffered) type Iter = BufferedIterator[Product2[K, C]] val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] { - // Use the reverse of comparator.compare because PriorityQueue dequeues the max - override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1) + // Use the reverse order because PriorityQueue dequeues the max + override def compare(x: Iter, y: Iter): Int = comparator.compare(y.head._1, x.head._1) }) heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true new Iterator[Product2[K, C]] { diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 7367af7888bd8..700ce56466c35 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -63,10 +63,19 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { */ def writeFully(channel: WritableByteChannel): Unit = { for (bytes <- getChunks()) { - while (bytes.remaining() > 0) { + val originalLimit = bytes.limit() + while (bytes.hasRemaining) { + // If `bytes` is an on-heap ByteBuffer, the Java NIO API will copy it to a temporary direct + // ByteBuffer when writing it out. This temporary direct ByteBuffer is cached per thread. + // Its size has no limit and can keep growing if it sees a larger input ByteBuffer. This may + // cause significant native memory leak, if a large direct ByteBuffer is allocated and + // cached, as it's never released until thread exits. Here we write the `bytes` with + // fixed-size slices to limit the size of the cached direct ByteBuffer. + // Please refer to http://www.evanjones.ca/java-bytebuffer-leak.html for more details. val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) bytes.limit(bytes.position() + ioSize) channel.write(bytes) + bytes.limit(originalLimit) } } } diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java index db91329c94cb6..0bbaea6b834b8 100644 --- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -17,6 +17,10 @@ package org.apache.spark.memory; +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.unsafe.memory.MemoryBlock; + import java.io.IOException; public class TestMemoryConsumer extends MemoryConsumer { @@ -43,6 +47,12 @@ void free(long size) { used -= size; taskMemoryManager.releaseExecutionMemory(size, this); } + + @VisibleForTesting + public void freePage(MemoryBlock page) { + used -= page.size(); + taskMemoryManager.freePage(page, this); + } } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index c145532328514..85ffdca436e14 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -129,7 +129,6 @@ public int compare( final UnsafeSorterIterator iter = sorter.getSortedIterator(); int iterLength = 0; long prevPrefix = -1; - Arrays.sort(dataToSort); while (iter.hasNext()) { iter.loadNext(); final String str = diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 3990ee1ec326d..5d0ffd92647bc 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -209,10 +209,8 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex System.gc() assert(ref.get.isEmpty) - // Getting a garbage collected accum should throw error - intercept[IllegalStateException] { - AccumulatorContext.get(accId) - } + // Getting a garbage collected accum should return None. + assert(AccumulatorContext.get(accId).isEmpty) // Getting a normal accumulator. Note: this has to be separate because referencing an // accumulator above in an `assert` would keep it from being garbage collected. diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 9807d1269e3d4..3cfb0a9feb32b 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -145,6 +145,39 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) } + def testAllocationRatio(cores: Int, divisor: Double, expected: Int): Unit = { + val conf = new SparkConf() + .setMaster("myDummyLocalExternalClusterManager") + .setAppName("test-executor-allocation-manager") + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.testing", "true") + .set("spark.dynamicAllocation.maxExecutors", "15") + .set("spark.dynamicAllocation.minExecutors", "3") + .set("spark.dynamicAllocation.executorAllocationRatio", divisor.toString) + .set("spark.executor.cores", cores.toString) + val sc = new SparkContext(conf) + contexts += sc + var manager = sc.executorAllocationManager.get + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 20))) + for (i <- 0 to 5) { + addExecutors(manager) + } + assert(numExecutorsTarget(manager) === expected) + sc.stop() + } + + test("executionAllocationRatio is correctly handled") { + testAllocationRatio(1, 0.5, 10) + testAllocationRatio(1, 1.0/3.0, 7) + testAllocationRatio(2, 1.0/3.0, 4) + testAllocationRatio(1, 0.385, 8) + + // max/min executors capping + testAllocationRatio(1, 1.0, 15) // should be 20 but capped by max + testAllocationRatio(4, 1.0/3.0, 3) // should be 2 but elevated by min + } + + test("add executors capped by num pending tasks") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 88916488c0def..b705556e54b14 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark import java.util.concurrent.{ExecutorService, TimeUnit} -import scala.collection.Map import scala.collection.mutable import scala.concurrent.Future import scala.concurrent.duration._ @@ -73,6 +72,7 @@ class HeartbeatReceiverSuite sc = spy(new SparkContext(conf)) scheduler = mock(classOf[TaskSchedulerImpl]) when(sc.taskScheduler).thenReturn(scheduler) + when(scheduler.nodeBlacklist).thenReturn(Predef.Set[String]()) when(scheduler.sc).thenReturn(sc) heartbeatReceiverClock = new ManualClock heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock) @@ -241,7 +241,7 @@ class HeartbeatReceiverSuite } === Some(true)) } - private def getTrackedExecutors: Map[String, Long] = { + private def getTrackedExecutors: collection.Map[String, Long] = { // We may receive undesired SparkListenerExecutorAdded from LocalSchedulerBackend, // so exclude it from the map. See SPARK-10800. heartbeatReceiver.invokePrivate(_executorLastSeen()). @@ -272,7 +272,7 @@ private class FakeSchedulerBackend( protected override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { clusterManagerEndpoint.ask[Boolean]( - RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount, Set.empty[String])) + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount, Set.empty)) } protected override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 50b8ea754d8d9..21f481d477242 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -147,7 +147,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) @@ -298,4 +298,33 @@ class MapOutputTrackerSuite extends SparkFunSuite { } } + test("zero-sized blocks should be excluded when getMapSizesByExecutorId") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(10, 2) + + val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L)) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(size0, size1000, size0, size10000))) + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), + Array(size10000, size0, size1000, size0))) + assert(tracker.containsShuffle(10)) + assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === + Seq( + (BlockManagerId("a", "hostA", 1000), + Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), + (BlockManagerId("b", "hostB", 1000), + Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))) + ) + ) + + tracker.unregisterShuffle(10) + tracker.stop() + rpcEnv.shutdown() + } + } diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 8eabc2b3cb958..5dbfc5c10a6f8 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -18,8 +18,11 @@ package org.apache.spark import java.io.File +import java.util.UUID import javax.net.ssl.SSLContext +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.alias.{CredentialProvider, CredentialProviderFactory} import org.scalatest.BeforeAndAfterAll import org.apache.spark.util.SparkConfWithEnv @@ -40,6 +43,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { .toSet val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -49,7 +53,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.enabledAlgorithms", algorithms.mkString(",")) conf.set("spark.ssl.protocol", "TLSv1.2") - val opts = SSLOptions.parse(conf, "spark.ssl") + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl") assert(opts.enabled === true) assert(opts.trustStore.isDefined === true) @@ -70,6 +74,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -80,8 +85,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") conf.set("spark.ssl.protocol", "SSLv3") - val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === true) assert(opts.trustStore.isDefined === true) @@ -103,6 +108,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.ui.enabled", "false") conf.set("spark.ssl.ui.port", "4242") @@ -117,8 +123,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.ui.enabledAlgorithms", "ABC, DEF") conf.set("spark.ssl.protocol", "SSLv3") - val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === false) assert(opts.port === Some(4242)) @@ -139,14 +145,71 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val conf = new SparkConfWithEnv(Map( "ENV1" -> "val1", "ENV2" -> "val2")) + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", "${env:ENV1}") conf.set("spark.ssl.trustStore", "${env:ENV2}") - val opts = SSLOptions.parse(conf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) assert(opts.keyStore === Some(new File("val1"))) assert(opts.trustStore === Some(new File("val2"))) } + test("get password from Hadoop credential provider") { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + val conf = new SparkConf + val hadoopConf = new Configuration() + val tmpPath = s"localjceks://file${sys.props("java.io.tmpdir")}/test-" + + s"${UUID.randomUUID().toString}.jceks" + val provider = createCredentialProvider(tmpPath, hadoopConf) + + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", keyStorePath) + storePassword(provider, "spark.ssl.keyStorePassword", "password") + storePassword(provider, "spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + storePassword(provider, "spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.protocol", "SSLv3") + + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) + + assert(opts.enabled === true) + assert(opts.trustStore.isDefined === true) + assert(opts.trustStore.get.getName === "truststore") + assert(opts.trustStore.get.getAbsolutePath === trustStorePath) + assert(opts.keyStore.isDefined === true) + assert(opts.keyStore.get.getName === "keystore") + assert(opts.keyStore.get.getAbsolutePath === keyStorePath) + assert(opts.trustStorePassword === Some("password")) + assert(opts.keyStorePassword === Some("password")) + assert(opts.keyPassword === Some("password")) + assert(opts.protocol === Some("SSLv3")) + assert(opts.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + } + + private def createCredentialProvider(tmpPath: String, conf: Configuration): CredentialProvider = { + conf.set(CredentialProviderFactory.CREDENTIAL_PROVIDER_PATH, tmpPath) + + val provider = CredentialProviderFactory.getProviders(conf).get(0) + if (provider == null) { + throw new IllegalStateException(s"Fail to get credential provider with path $tmpPath") + } + + provider + } + + private def storePassword( + provider: CredentialProvider, + passwordKey: String, + password: String): Unit = { + provider.createCredentialEntry(passwordKey, password.toCharArray) + provider.flush() + } } diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index bff808eb540ac..0d06b02e74e34 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -339,6 +339,38 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } } + val defaultIllegalValue = "SomeIllegalValue" + val illegalValueTests : Map[String, (SparkConf, String) => Any] = Map( + "getTimeAsSeconds" -> (_.getTimeAsSeconds(_)), + "getTimeAsSeconds with default" -> (_.getTimeAsSeconds(_, defaultIllegalValue)), + "getTimeAsMs" -> (_.getTimeAsMs(_)), + "getTimeAsMs with default" -> (_.getTimeAsMs(_, defaultIllegalValue)), + "getSizeAsBytes" -> (_.getSizeAsBytes(_)), + "getSizeAsBytes with default string" -> (_.getSizeAsBytes(_, defaultIllegalValue)), + "getSizeAsBytes with default long" -> (_.getSizeAsBytes(_, 0L)), + "getSizeAsKb" -> (_.getSizeAsKb(_)), + "getSizeAsKb with default" -> (_.getSizeAsKb(_, defaultIllegalValue)), + "getSizeAsMb" -> (_.getSizeAsMb(_)), + "getSizeAsMb with default" -> (_.getSizeAsMb(_, defaultIllegalValue)), + "getSizeAsGb" -> (_.getSizeAsGb(_)), + "getSizeAsGb with default" -> (_.getSizeAsGb(_, defaultIllegalValue)), + "getInt" -> (_.getInt(_, 0)), + "getLong" -> (_.getLong(_, 0L)), + "getDouble" -> (_.getDouble(_, 0.0)), + "getBoolean" -> (_.getBoolean(_, false)) + ) + + illegalValueTests.foreach { case (name, getValue) => + test(s"SPARK-24337: $name throws an useful error message with key name") { + val key = "SomeKey" + val conf = new SparkConf() + conf.set(key, "SomeInvalidValue") + val thrown = intercept[IllegalArgumentException] { + getValue(conf, key) + } + assert(thrown.getMessage.contains(key)) + } + } } class Class1 {} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 7451e07b25a1f..f829fecc30840 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -180,6 +180,26 @@ class SparkSubmitSuite appArgs.toString should include ("thequeue") } + test("SPARK-24241: do not fail fast if executor num is 0 when dynamic allocation is enabled") { + val clArgs1 = Seq( + "--name", "myApp", + "--class", "Foo", + "--num-executors", "0", + "--conf", "spark.dynamicAllocation.enabled=true", + "thejar.jar") + new SparkSubmitArguments(clArgs1) + + val clArgs2 = Seq( + "--name", "myApp", + "--class", "Foo", + "--num-executors", "0", + "--conf", "spark.dynamicAllocation.enabled=false", + "thejar.jar") + + val e = intercept[SparkException](new SparkSubmitArguments(clArgs2)) + assert(e.getMessage.contains("Number of executors must be a positive number")) + } + test("specify deploy mode through configuration") { val clArgs = Seq( "--master", "yarn", @@ -751,9 +771,13 @@ class SparkSubmitSuite PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) // Test remote python files + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) val f4 = File.createTempFile("test-submit-remote-python-files", "", tmpDir) + val pyFile1 = File.createTempFile("file1", ".py", tmpDir) + val pyFile2 = File.createTempFile("file2", ".py", tmpDir) val writer4 = new PrintWriter(f4) - val remotePyFiles = "hdfs:///tmp/file1.py,hdfs:///tmp/file2.py" + val remotePyFiles = s"s3a://${pyFile1.getAbsolutePath},s3a://${pyFile2.getAbsolutePath}" writer4.println("spark.submit.pyFiles " + remotePyFiles) writer4.close() val clArgs4 = Seq( @@ -763,7 +787,7 @@ class SparkSubmitSuite "hdfs:///tmp/mister.py" ) val appArgs4 = new SparkSubmitArguments(clArgs4) - val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4) + val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4, conf = Some(hadoopConf)) // Should not format python path for yarn cluster mode conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) } @@ -971,20 +995,24 @@ class SparkSubmitSuite } test("download remote resource if it is not supported by yarn service") { - testRemoteResources(enableHttpFs = false, blacklistHttpFs = false) + testRemoteResources(enableHttpFs = false) } test("avoid downloading remote resource if it is supported by yarn service") { - testRemoteResources(enableHttpFs = true, blacklistHttpFs = false) + testRemoteResources(enableHttpFs = true) } test("force download from blacklisted schemes") { - testRemoteResources(enableHttpFs = true, blacklistHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("http")) + } + + test("force download for all the schemes") { + testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("*")) } private def testRemoteResources( enableHttpFs: Boolean, - blacklistHttpFs: Boolean): Unit = { + blacklistSchemes: Seq[String] = Nil): Unit = { val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) if (enableHttpFs) { @@ -1001,8 +1029,8 @@ class SparkSubmitSuite val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}" - val forceDownloadArgs = if (blacklistHttpFs) { - Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http") + val forceDownloadArgs = if (blacklistSchemes.nonEmpty) { + Seq("--conf", s"spark.yarn.dist.forceDownloadSchemes=${blacklistSchemes.mkString(",")}") } else { Nil } @@ -1020,14 +1048,19 @@ class SparkSubmitSuite val jars = conf.get("spark.yarn.dist.jars").split(",").toSet - // The URI of remote S3 resource should still be remote. - assert(jars.contains(tmpS3JarPath)) + def isSchemeBlacklisted(scheme: String) = { + blacklistSchemes.contains("*") || blacklistSchemes.contains(scheme) + } + + if (!isSchemeBlacklisted("s3")) { + assert(jars.contains(tmpS3JarPath)) + } - if (enableHttpFs && !blacklistHttpFs) { + if (enableHttpFs && blacklistSchemes.isEmpty) { // If Http FS is supported by yarn service, the URI of remote http resource should // still be remote. assert(jars.contains(tmpHttpJarPath)) - } else { + } else if (!enableHttpFs || isSchemeBlacklisted("http")) { // If Http FS is not supported by yarn service, or http scheme is configured to be force // downloading, the URI of remote http resource should be changed to a local one. val jarName = new File(tmpHttpJar.toURI).getName @@ -1073,6 +1106,44 @@ class SparkSubmitSuite assert(exception.getMessage() === "hello") } + test("support --py-files/spark.submit.pyFiles in non pyspark application") { + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) + + val tmpDir = Utils.createTempDir() + val pyFile = File.createTempFile("tmpPy", ".egg", tmpDir) + + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--py-files", s"s3a://${pyFile.getAbsolutePath}", + "spark-internal" + ) + + val appArgs = new SparkSubmitArguments(args) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) + + conf.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}") + conf.get("spark.submit.pyFiles") should (startWith("/")) + + // Verify "spark.submit.pyFiles" + val args1 = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--conf", s"spark.submit.pyFiles=s3a://${pyFile.getAbsolutePath}", + "spark-internal" + ) + + val appArgs1 = new SparkSubmitArguments(args1) + val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1, conf = Some(hadoopConf)) + + conf1.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}") + conf1.get("spark.submit.pyFiles") should (startWith("/")) + } } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index eb8c203ae7751..a0f09891787e0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -256,4 +256,19 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(jarPath.indexOf("mydep") >= 0, "should find dependency") } } + + test("SPARK-10878: test resolution files cleaned after resolving artifact") { + val main = new MavenCoordinate("my.great.lib", "mylib", "0.1") + + IvyTestUtils.withRepository(main, None, None) { repo => + val ivySettings = SparkSubmitUtils.buildIvySettings(Some(repo), Some(tempIvyPath)) + val jarPath = SparkSubmitUtils.resolveMavenCoordinates( + main.toString, + ivySettings, + isTest = true) + val r = """.*org.apache.spark-spark-submit-parent-.*""".r + assert(!ivySettings.getDefaultCache.listFiles.map(_.getName) + .exists(r.findFirstIn(_).isDefined), "resolution files should be cleaned") + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 87f12f303cd5e..11b29121739a4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -36,6 +36,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods._ +import org.mockito.Mockito._ import org.openqa.selenium.WebDriver import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest.{BeforeAndAfter, Matchers} @@ -281,6 +282,29 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND) } + test("automatically retrieve uiRoot from request through Knox") { + assert(sys.props.get("spark.ui.proxyBase").isEmpty, + "spark.ui.proxyBase is defined but it should not for this UT") + assert(sys.env.get("APPLICATION_WEB_PROXY_BASE").isEmpty, + "APPLICATION_WEB_PROXY_BASE is defined but it should not for this UT") + val page = new HistoryPage(server) + val requestThroughKnox = mock[HttpServletRequest] + val knoxBaseUrl = "/gateway/default/sparkhistoryui" + when(requestThroughKnox.getHeader("X-Forwarded-Context")).thenReturn(knoxBaseUrl) + val responseThroughKnox = page.render(requestThroughKnox) + + val urlsThroughKnox = responseThroughKnox \\ "@href" map (_.toString) + val siteRelativeLinksThroughKnox = urlsThroughKnox filter (_.startsWith("/")) + all (siteRelativeLinksThroughKnox) should startWith (knoxBaseUrl) + + val directRequest = mock[HttpServletRequest] + val directResponse = page.render(directRequest) + + val directUrls = directResponse \\ "@href" map (_.toString) + val directSiteRelativeLinks = directUrls filter (_.startsWith("/")) + all (directSiteRelativeLinks) should not startWith (knoxBaseUrl) + } + test("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") val page = new HistoryPage(server) @@ -296,6 +320,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers all (siteRelativeLinks) should startWith (uiRoot) } + test("/version api endpoint") { + val response = getUrl("version") + assert(response.contains(SPARK_VERSION)) + } + test("ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = "/testwebproxybase" System.setProperty("spark.ui.proxyBase", uiRoot) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index ce212a7513310..e3fe2b696aa1f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,10 +17,19 @@ package org.apache.spark.deploy.worker +import java.util.concurrent.atomic.AtomicBoolean +import java.util.function.Supplier + +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.deploy.{Command, ExecutorState} +import org.apache.spark.deploy.{Command, ExecutorState, ExternalShuffleService} import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged} import org.apache.spark.deploy.master.DriverState import org.apache.spark.rpc.{RpcAddress, RpcEnv} @@ -29,6 +38,8 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { import org.apache.spark.deploy.DeployTestUtils._ + @Mock(answer = RETURNS_SMART_NULLS) private var shuffleService: ExternalShuffleService = _ + def cmd(javaOpts: String*): Command = { Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts : _*)) } @@ -36,15 +47,21 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { private var _worker: Worker = _ - private def makeWorker(conf: SparkConf): Worker = { + private def makeWorker( + conf: SparkConf, + shuffleServiceSupplier: Supplier[ExternalShuffleService] = null): Worker = { assert(_worker === null, "Some Worker's RpcEnv is leaked in tests") val securityMgr = new SecurityManager(conf) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, securityMgr) _worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, securityMgr) + "Worker", "/tmp", conf, securityMgr, shuffleServiceSupplier) _worker } + before { + MockitoAnnotations.initMocks(this) + } + after { if (_worker != null) { _worker.rpcEnv.shutdown() @@ -194,4 +211,36 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { assert(worker.finishedDrivers.size === expectedValue) } } + + test("cleanup non-shuffle files after executor exits when config " + + "spark.storage.cleanupFilesAfterExecutorExit=true") { + testCleanupFilesWithConfig(true) + } + + test("don't cleanup non-shuffle files after executor exits when config " + + "spark.storage.cleanupFilesAfterExecutorExit=false") { + testCleanupFilesWithConfig(false) + } + + private def testCleanupFilesWithConfig(value: Boolean) = { + val conf = new SparkConf().set("spark.storage.cleanupFilesAfterExecutorExit", value.toString) + + val cleanupCalled = new AtomicBoolean(false) + when(shuffleService.executorRemoved(any[String], any[String])).thenAnswer(new Answer[Unit] { + override def answer(invocations: InvocationOnMock): Unit = { + cleanupCalled.set(true) + } + }) + val externalShuffleServiceSupplier = new Supplier[ExternalShuffleService] { + override def get: ExternalShuffleService = shuffleService + } + val worker = makeWorker(conf, externalShuffleServiceSupplier) + // initialize workers + for (i <- 0 until 10) { + worker.executors += s"app1/$i" -> createExecutorRunner(i) + } + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None)) + assert(cleanupCalled.get() == value) + } } diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala new file mode 100644 index 0000000000000..817dc082b7d38 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import java.io.{DataOutputStream, File, FileOutputStream} + +import scala.collection.immutable.IndexedSeq + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * Tests the correctness of + * [[org.apache.spark.input.WholeTextFileInputFormat WholeTextFileInputFormat]]. A temporary + * directory containing files is created as fake input which is deleted in the end. + */ +class WholeTextFileInputFormatSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { + private var sc: SparkContext = _ + + override def beforeAll() { + super.beforeAll() + val conf = new SparkConf() + sc = new SparkContext("local", "test", conf) + } + + override def afterAll() { + try { + sc.stop() + } finally { + super.afterAll() + } + } + + private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte], + compress: Boolean) = { + val path = s"${inputDir.toString}/$fileName" + val out = new DataOutputStream(new FileOutputStream(path)) + out.write(contents, 0, contents.length) + out.close() + } + + test("for small files minimum split size per node and per rack should be less than or equal to " + + "maximum split size.") { + var dir : File = null; + try { + dir = Utils.createTempDir() + logInfo(s"Local disk address is ${dir.toString}.") + + // Set the minsize per node and rack to be larger than the size of the input file. + sc.hadoopConfiguration.setLong( + "mapreduce.input.fileinputformat.split.minsize.per.node", 123456) + sc.hadoopConfiguration.setLong( + "mapreduce.input.fileinputformat.split.minsize.per.rack", 123456) + + WholeTextFileInputFormatSuite.files.foreach { case (filename, contents) => + createNativeFile(dir, filename, contents, false) + } + // ensure spark job runs successfully without exceptions from the CombineFileInputFormat + assert(sc.wholeTextFiles(dir.toString).count == 3) + } finally { + Utils.deleteRecursively(dir) + } + } +} + +/** + * Files to be tested are defined here. + */ +object WholeTextFileInputFormatSuite { + private val testWords: IndexedSeq[Byte] = "Spark is easy to use.\n".map(_.toByte) + + private val fileNames = Array("part-00000", "part-00001", "part-00002") + private val fileLengths = Array(10, 100, 1000) + + private val files = fileLengths.zip(fileNames).map { case (upperBound, filename) => + filename -> Stream.continually(testWords.toList.toStream).flatten.take(upperBound).toArray + }.toMap +} diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index 3b798e36b0499..2107559572d78 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -21,11 +21,12 @@ import java.nio.ByteBuffer import com.google.common.io.ByteStreams -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.network.util.ByteArrayWritableChannel import org.apache.spark.util.io.ChunkedByteBuffer -class ChunkedByteBufferSuite extends SparkFunSuite { +class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext { test("no chunks") { val emptyChunkedByteBuffer = new ChunkedByteBuffer(Array.empty[ByteBuffer]) @@ -56,6 +57,18 @@ class ChunkedByteBufferSuite extends SparkFunSuite { assert(chunkedByteBuffer.getChunks().head.position() === 0) } + test("SPARK-24107: writeFully() write buffer which is larger than bufferWriteChunkSize") { + try { + sc.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, 32L * 1024L * 1024L) + val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(40 * 1024 * 1024))) + val byteArrayWritableChannel = new ByteArrayWritableChannel(chunkedByteBuffer.size.toInt) + chunkedByteBuffer.writeFully(byteArrayWritableChannel) + assert(byteArrayWritableChannel.length() === chunkedByteBuffer.size) + } finally { + sc.conf.remove(config.BUFFER_WRITE_CHUNK_SIZE) + } + } + test("toArray()") { val empty = ByteBuffer.wrap(Array.empty[Byte]) val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte)) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 191c61250ce21..5148ce05bd918 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -154,6 +154,16 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } } + test("SPARK-23778: empty RDD in union should not produce a UnionRDD") { + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val emptyRDD = sc.emptyRDD[(Int, Boolean)] + val unionRDD = sc.union(emptyRDD, rddWithPartitioner) + assert(unionRDD.isInstanceOf[PartitionerAwareUnionRDD[_]]) + val unionAllEmptyRDD = sc.union(emptyRDD, emptyRDD) + assert(unionAllEmptyRDD.isInstanceOf[UnionRDD[_]]) + assert(unionAllEmptyRDD.collect().isEmpty) + } + test("partitioner aware union") { def makeRDDWithPartitioner(seq: Seq[Int]): RDD[Int] = { sc.makeRDD(seq, 1) @@ -1047,7 +1057,9 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { private class CyclicalDependencyRDD[T: ClassTag] extends RDD[T](sc, Nil) { private val mutableDependencies: ArrayBuffer[Dependency[_]] = ArrayBuffer.empty override def compute(p: Partition, c: TaskContext): Iterator[T] = Iterator.empty - override def getPartitions: Array[Partition] = Array.empty + override def getPartitions: Array[Partition] = Array(new Partition { + override def index: Int = 0 + }) override def getDependencies: Seq[Dependency[_]] = mutableDependencies def addDependency(dep: Dependency[_]) { mutableDependencies += dep diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8b6ec37625eec..2987170bf5026 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1852,7 +1852,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assertDataStructuresEmpty() } - test("accumulators are updated on exception failures") { + test("accumulators are updated on exception failures and task killed") { val acc1 = AccumulatorSuite.createLongAccum("ingenieur") val acc2 = AccumulatorSuite.createLongAccum("boulanger") val acc3 = AccumulatorSuite.createLongAccum("agriculteur") @@ -1868,15 +1868,24 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi val accUpdate3 = new LongAccumulator accUpdate3.metadata = acc3.metadata accUpdate3.setValue(18) - val accumUpdates = Seq(accUpdate1, accUpdate2, accUpdate3) - val accumInfo = accumUpdates.map(AccumulatorSuite.makeInfo) + + val accumUpdates1 = Seq(accUpdate1, accUpdate2) + val accumInfo1 = accumUpdates1.map(AccumulatorSuite.makeInfo) val exceptionFailure = new ExceptionFailure( new SparkException("fondue?"), - accumInfo).copy(accums = accumUpdates) + accumInfo1).copy(accums = accumUpdates1) submit(new MyRDD(sc, 1, Nil), Array(0)) runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result")) + assert(AccumulatorContext.get(acc1.id).get.value === 15L) assert(AccumulatorContext.get(acc2.id).get.value === 13L) + + val accumUpdates2 = Seq(accUpdate3) + val accumInfo2 = accumUpdates2.map(AccumulatorSuite.makeInfo) + + val taskKilled = new TaskKilled( "test", accumInfo2, accums = accumUpdates2) + runEvent(makeCompletionEvent(taskSets.head.tasks.head, taskKilled, "result")) + assert(AccumulatorContext.get(acc3.id).get.value === 18L) } @@ -2497,6 +2506,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi val accumUpdates = reason match { case Success => task.metrics.accumulators() case ef: ExceptionFailure => ef.accums + case tk: TaskKilled => tk.accums case _ => Seq.empty } CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 2155a0f2b6c21..354e6386fa60e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -188,4 +188,32 @@ class MapStatusSuite extends SparkFunSuite { assert(count === 3000) } } + + test("SPARK-24519: HighlyCompressedMapStatus has configurable threshold") { + val conf = new SparkConf() + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) + val sizes = Array.fill[Long](500)(150L) + // Test default value + val status = MapStatus(null, sizes) + assert(status.isInstanceOf[CompressedMapStatus]) + // Test Non-positive values + for (s <- -1 to 0) { + assertThrows[IllegalArgumentException] { + conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) + val status = MapStatus(null, sizes) + } + } + // Test positive values + Seq(1, 100, 499, 500, 501).foreach { s => + conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) + val status = MapStatus(null, sizes) + if(sizes.length > s) { + assert(status.isInstanceOf[HighlyCompressedMapStatus]) + } else { + assert(status.isInstanceOf[CompressedMapStatus]) + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 03b1903902491..158c9eb75f2b6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapRedCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.rdd.{FakeOutputCommitter, RDD} +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.{ThreadUtils, Utils} /** @@ -153,7 +154,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Job should not complete if all commits are denied") { // Create a mock OutputCommitCoordinator that denies all attempts to commit doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit( - Matchers.any(), Matchers.any(), Matchers.any()) + Matchers.any(), Matchers.any(), Matchers.any(), Matchers.any()) val rdd: RDD[Int] = sc.parallelize(Seq(1), 1) def resultHandler(x: Int, y: Unit): Unit = {} val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd, @@ -169,45 +170,106 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { val stage: Int = 1 + val stageAttempt: Int = 1 val partition: Int = 2 val authorizedCommitter: Int = 3 val nonAuthorizedCommitter: Int = 100 outputCommitCoordinator.stageStart(stage, maxPartitionId = 2) - assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter)) - assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter)) // The non-authorized committer fails - outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) // New tasks should still not be able to commit because the authorized committer has not failed - assert( - !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock - outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled("test")) + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = authorizedCommitter, reason = TaskKilled("test")) // A new task should now be allowed to become the authorized committer - assert( - outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 2)) // There can only be one authorized committer - assert( - !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 3)) - } - - test("Duplicate calls to canCommit from the authorized committer gets idempotent responses.") { - val rdd = sc.parallelize(Seq(1), 1) - sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).callCanCommitMultipleTimes _, - 0 until rdd.partitions.size) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 3)) } test("SPARK-19631: Do not allow failed attempts to be authorized for committing") { val stage: Int = 1 + val stageAttempt: Int = 1 val partition: Int = 1 val failedAttempt: Int = 0 outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) - outputCommitCoordinator.taskCompleted(stage, partition, attemptNumber = failedAttempt, + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = failedAttempt, reason = ExecutorLostFailure("0", exitCausedByApp = true, None)) - assert(!outputCommitCoordinator.canCommit(stage, partition, failedAttempt)) - assert(outputCommitCoordinator.canCommit(stage, partition, failedAttempt + 1)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt + 1)) + } + + test("SPARK-24589: Differentiate tasks from different stage attempts") { + var stage = 1 + val taskAttempt = 1 + val partition = 1 + + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + assert(!outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt)) + + // Fail the task in the first attempt, the task in the second attempt should succeed. + stage += 1 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + assert(outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt)) + + // Commit the 1st attempt, fail the 2nd attempt, make sure 3rd attempt cannot commit, + // then fail the 1st attempt and make sure the 4th one can commit again. + stage += 1 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + outputCommitCoordinator.taskCompleted(stage, 2, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, 3, partition, taskAttempt)) + outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(outputCommitCoordinator.canCommit(stage, 4, partition, taskAttempt)) + } + + test("SPARK-24589: Make sure stage state is cleaned up") { + // Normal application without stage failures. + sc.parallelize(1 to 100, 100) + .map { i => (i % 10, i) } + .reduceByKey(_ + _) + .collect() + + assert(sc.dagScheduler.outputCommitCoordinator.isEmpty) + + // Force failures in a few tasks so that a stage is retried. Collect the ID of the failing + // stage so that we can check the state of the output committer. + val retriedStage = sc.parallelize(1 to 100, 10) + .map { i => (i % 10, i) } + .reduceByKey { case (_, _) => + val ctx = TaskContext.get() + if (ctx.stageAttemptNumber() == 0) { + throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 1, 1, 1, + new Exception("Failure for test.")) + } else { + ctx.stageId() + } + } + .collect() + .map { case (k, v) => v } + .toSet + + assert(retriedStage.size === 1) + assert(sc.dagScheduler.outputCommitCoordinator.isEmpty) + verify(sc.env.outputCommitCoordinator, times(2)) + .stageStart(Matchers.eq(retriedStage.head), Matchers.any()) + verify(sc.env.outputCommitCoordinator).stageEnd(Matchers.eq(retriedStage.head)) } } @@ -243,16 +305,6 @@ private case class OutputCommitFunctions(tempDirPath: String) { if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter) } - // Receiver should be idempotent for AskPermissionToCommitOutput - def callCanCommitMultipleTimes(iter: Iterator[Int]): Unit = { - val ctx = TaskContext.get() - val canCommit1 = SparkEnv.get.outputCommitCoordinator - .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber()) - val canCommit2 = SparkEnv.get.outputCommitCoordinator - .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber()) - assert(canCommit1 && canCommit2) - } - private def runCommitWithProvidedCommitter( ctx: TaskContext, iter: Iterator[Int], diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index da6ecb82c7e42..6ffd1e84f7adb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.Semaphore import scala.collection.JavaConverters._ @@ -294,10 +295,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) sc.addSparkListener(new StatsReportListener) - // just to make sure some of the tasks take a noticeable amount of time + // just to make sure some of the tasks and their deserialization take a noticeable + // amount of time + val slowDeserializable = new SlowDeserializable val w = { i: Int => if (i == 0) { Thread.sleep(100) + slowDeserializable.use() } i } @@ -485,6 +489,48 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match assert(bus.findListenersByClass[BasicJobCounter]().isEmpty) } + Seq(true, false).foreach { throwInterruptedException => + val suffix = if (throwInterruptedException) "throw interrupt" else "set Thread interrupted" + test(s"interrupt within listener is handled correctly: $suffix") { + val conf = new SparkConf(false) + .set(LISTENER_BUS_EVENT_QUEUE_CAPACITY, 5) + val bus = new LiveListenerBus(conf) + val counter1 = new BasicJobCounter() + val counter2 = new BasicJobCounter() + val interruptingListener1 = new InterruptingListener(throwInterruptedException) + val interruptingListener2 = new InterruptingListener(throwInterruptedException) + bus.addToSharedQueue(counter1) + bus.addToSharedQueue(interruptingListener1) + bus.addToStatusQueue(counter2) + bus.addToEventLogQueue(interruptingListener2) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE, EVENT_LOG_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + assert(bus.findListenersByClass[InterruptingListener]().size === 2) + + bus.start(mockSparkContext, mockMetricsSystem) + + // after we post one event, both interrupting listeners should get removed, and the + // event log queue should be removed + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + assert(bus.findListenersByClass[InterruptingListener]().size === 0) + assert(counter1.count === 1) + assert(counter2.count === 1) + + // posting more events should be fine, they'll just get processed from the OK queue. + (0 until 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(counter1.count === 6) + assert(counter2.count === 6) + + // Make sure stopping works -- this requires putting a poison pill in all active queues, which + // would fail if our interrupted queue was still active, as its queue would be full. + bus.stop() + } + } + /** * Assert that the given list of numbers has an average that is greater than zero. */ @@ -543,6 +589,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception } } + /** + * A simple listener that interrupts on job end. + */ + private class InterruptingListener(val throwInterruptedException: Boolean) extends SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + if (throwInterruptedException) { + throw new InterruptedException("got interrupted") + } else { + Thread.currentThread().interrupt() + } + } + } } // These classes can't be declared inside of the SparkListenerSuite class because we don't want @@ -583,3 +641,12 @@ private class FirehoseListenerThatAcceptsSparkConf(conf: SparkConf) extends Spar case _ => } } + +private class SlowDeserializable extends Externalizable { + + override def writeExternal(out: ObjectOutput): Unit = { } + + override def readExternal(in: ObjectInput): Unit = Thread.sleep(1) + + def use(): Unit = { } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 6003899bb7bef..33f2ea1c94e75 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -917,4 +917,108 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B taskScheduler.initialize(new FakeSchedulerBackend) } } + + test("Completions in zombie tasksets update status of non-zombie taskset") { + val taskScheduler = setupSchedulerWithMockTaskSetBlacklist() + val valueSer = SparkEnv.get.serializer.newInstance() + + def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = { + val indexInTsm = tsm.partitionToIndex(partition) + val matchingTaskInfo = tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result) + } + + // Submit a task set, have it fail with a fetch failed, and then re-submit the task attempt, + // two times, so we have three active task sets for one stage. (For this to really happen, + // you'd need the previous stage to also get restarted, and then succeed, in between each + // attempt, but that happens outside what we're mocking here.) + val zombieAttempts = (0 until 2).map { stageAttempt => + val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt) + taskScheduler.submitTasks(attempt) + val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get + val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + taskScheduler.resourceOffers(offers) + assert(tsm.runningTasks === 10) + // fail attempt + tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, + FetchFailed(null, 0, 0, 0, "fetch failed")) + // the attempt is a zombie, but the tasks are still running (this could be true even if + // we actively killed those tasks, as killing is best-effort) + assert(tsm.isZombie) + assert(tsm.runningTasks === 9) + tsm + } + + // we've now got 2 zombie attempts, each with 9 tasks still active. Submit the 3rd attempt for + // the stage, but this time with insufficient resources so not all tasks are active. + + val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2) + taskScheduler.submitTasks(finalAttempt) + val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get + val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task => + finalAttempt.tasks(task.index).partitionId + }.toSet + assert(finalTsm.runningTasks === 5) + assert(!finalTsm.isZombie) + + // We simulate late completions from our zombie tasksets, corresponding to all the pending + // partitions in our final attempt. This means we're only waiting on the tasks we've already + // launched. + val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions) + finalAttemptPendingPartitions.foreach { partition => + completeTaskSuccessfully(zombieAttempts(0), partition) + } + + // If there is another resource offer, we shouldn't run anything. Though our final attempt + // used to have pending tasks, now those tasks have been completed by zombie attempts. The + // remaining tasks to compute are already active in the non-zombie attempt. + assert( + taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty) + + val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted + + // finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be + // marked as zombie. + // for each of the remaining tasks, find the tasksets with an active copy of the task, and + // finish the task. + remainingTasks.foreach { partition => + val tsm = if (partition == 0) { + // we failed this task on both zombie attempts, this one is only present in the latest + // taskset + finalTsm + } else { + // should be active in every taskset. We choose a zombie taskset just to make sure that + // we transition the active taskset correctly even if the final completion comes + // from a zombie. + zombieAttempts(partition % 2) + } + completeTaskSuccessfully(tsm, partition) + } + + assert(finalTsm.isZombie) + + // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet + verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), anyObject()) + + // finally, lets complete all the tasks. We simulate failures in attempt 1, but everything + // else succeeds, to make sure we get the right updates to the blacklist in all cases. + (zombieAttempts ++ Seq(finalTsm)).foreach { tsm => + val stageAttempt = tsm.taskSet.stageAttemptId + tsm.runningTasksSet.foreach { index => + if (stageAttempt == 1) { + tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost) + } else { + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result) + } + } + + // we update the blacklist for the stage attempts with all successful tasks. Even though + // some tasksets had failures, we still consider them all successful from a blacklisting + // perspective, as the failures weren't from a problem w/ the tasks themselves. + verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject()) + } + } } diff --git a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala new file mode 100644 index 0000000000000..e57cb701b6284 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.security + +import java.io.Closeable +import java.net._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils + +class SocketAuthHelperSuite extends SparkFunSuite { + + private val conf = new SparkConf() + private val authHelper = new SocketAuthHelper(conf) + + test("successful auth") { + Utils.tryWithResource(new ServerThread()) { server => + Utils.tryWithResource(server.createClient()) { client => + authHelper.authToServer(client) + server.close() + server.join() + assert(server.error == null) + assert(server.authenticated) + } + } + } + + test("failed auth") { + Utils.tryWithResource(new ServerThread()) { server => + Utils.tryWithResource(server.createClient()) { client => + val badHelper = new SocketAuthHelper(new SparkConf().set(AUTH_SECRET_BIT_LENGTH, 128)) + intercept[IllegalArgumentException] { + badHelper.authToServer(client) + } + server.close() + server.join() + assert(server.error != null) + assert(!server.authenticated) + } + } + } + + private class ServerThread extends Thread with Closeable { + + private val ss = new ServerSocket() + ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)) + + @volatile var error: Exception = _ + @volatile var authenticated = false + + setDaemon(true) + start() + + def createClient(): Socket = { + new Socket(InetAddress.getLoopbackAddress(), ss.getLocalPort()) + } + + override def run(): Unit = { + var clientConn: Socket = null + try { + clientConn = ss.accept() + authHelper.authClient(clientConn) + authenticated = true + } catch { + case e: Exception => + error = e + } finally { + Option(clientConn).foreach(_.close()) + } + } + + override def close(): Unit = { + try { + ss.close() + } finally { + interrupt() + } + } + + } + +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index dba1172d5fdbd..2d8a83c6fabed 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -108,7 +108,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) (shuffleBlockId, byteOutputStream.size().toLong) } - Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) + Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator } // Create a mocked shuffle handle to pass into HashShuffleReader. diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index b19d8ebf72c61..08172f0b07b75 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1422,6 +1422,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(mockBlockTransferService.tempFileManager === store.remoteBlockTempFileManager) } + test("query locations of blockIds") { + val mockBlockManagerMaster = mock(classOf[BlockManagerMaster]) + val blockLocations = Seq(BlockManagerId("1", "host1", 100), BlockManagerId("2", "host2", 200)) + when(mockBlockManagerMaster.getLocations(mc.any[Array[BlockId]])) + .thenReturn(Array(blockLocations)) + val env = mock(classOf[SparkEnv]) + + val blockIds: Array[BlockId] = Array(StreamBlockId(1, 2)) + val locs = BlockManager.blockIdsToLocations(blockIds, env, mockBlockManagerMaster) + val expectedLocs = Seq("executor_host1_1", "executor_host2_2") + assert(locs(blockIds(0)) == expectedLocs) + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 var tempFileManager: TempFileManager = null diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 692ae3bf597e0..a2997dbd1b1ac 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -65,12 +65,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } // Create a mock managed buffer for testing - def createMockManagedBuffer(): ManagedBuffer = { + def createMockManagedBuffer(size: Int = 1): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) val in = mock(classOf[InputStream]) when(in.read(any())).thenReturn(1) when(in.read(any(), any(), any())).thenReturn(1) when(mockManagedBuffer.createInputStream()).thenReturn(in) + when(mockManagedBuffer.size()).thenReturn(size) mockManagedBuffer } @@ -99,7 +100,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq), (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) - ) + ).toIterator val iterator = new ShuffleBlockFetcherIterator( TaskContext.empty(), @@ -176,7 +177,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -244,7 +245,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -269,6 +270,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } + private def mockCorruptBuffer(size: Long = 1L): ManagedBuffer = { + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.size()).thenReturn(size) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + corruptBuffer + } + test("retry corrupt blocks") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -284,11 +294,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) @@ -301,7 +306,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) listener.onBlockFetchSuccess( ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer) sem.release() @@ -310,7 +315,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -339,7 +344,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Future { // Return the first block, and then fail. listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) sem.release() } } @@ -353,11 +358,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("big blocks are not checked for corruption") { - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) - doReturn(10000L).when(corruptBuffer).size() + val corruptBuffer = mockCorruptBuffer(10000L) val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -378,7 +379,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlockLengths), (remoteBmId, remoteBlockLengths) - ) + ).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -413,11 +414,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) - val transfer = mock(classOf[BlockTransferService]) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { @@ -428,16 +424,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 2, 0).toString, corruptBuffer) + ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer()) sem.release() } } }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -495,7 +491,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - def fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { + def fetchShuffleBlock( + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. @@ -513,17 +510,52 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. assert(tempFileManager == null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. assert(tempFileManager != null) } + + test("fail zero-size blocks") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer() + ) + + val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0))) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress.toIterator, + (_, in) => in, + 48 * 1024 * 1024, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true) + + // All blocks fetched return zero length and should trigger a receive-side error: + val e = intercept[FetchFailedException] { iterator.next() } + assert(e.getMessage.contains("Received a zero-size buffer")) + } } diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index a71521c91d2f2..cdc7f541b9552 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.storage +import javax.servlet.http.HttpServletRequest + import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite @@ -29,6 +31,7 @@ class StoragePageSuite extends SparkFunSuite { val storageTab = mock(classOf[StorageTab]) when(storageTab.basePath).thenReturn("http://localhost:4040") val storagePage = new StoragePage(storageTab, null) + val request = mock(classOf[HttpServletRequest]) test("rddTable") { val rdd1 = new RDDStorageInfo(1, @@ -61,7 +64,7 @@ class StoragePageSuite extends SparkFunSuite { None, None) - val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3)) + val xmlNodes = storagePage.rddTable(request, Seq(rdd1, rdd2, rdd3)) val headers = Seq( "ID", @@ -94,7 +97,7 @@ class StoragePageSuite extends SparkFunSuite { } test("empty rddTable") { - assert(storagePage.rddTable(Seq.empty).isEmpty) + assert(storagePage.rddTable(request, Seq.empty).isEmpty) } test("streamBlockStorageLevelDescriptionAndSize") { diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala index a04644d57ed88..fe0a9a471a651 100644 --- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import org.apache.spark._ +import org.apache.spark.serializer.JavaSerializer class AccumulatorV2Suite extends SparkFunSuite { @@ -162,4 +163,22 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc3.isZero) assert(acc3.value === "") } + + test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") { + class MyData(val i: Int) extends Serializable + val param = new AccumulatorParam[MyData] { + override def zero(initialValue: MyData): MyData = new MyData(0) + override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i) + } + + val acc = new LegacyAccumulatorWrapper(new MyData(0), param) + acc.metadata = AccumulatorMetadata( + AccumulatorContext.newId(), + Some("test"), + countFailedValues = false) + AccumulatorContext.register(acc) + + val ser = new JavaSerializer(new SparkConf).newInstance() + ser.serialize(acc) + } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 3b4273184f1e9..418d2f9b88500 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1168,6 +1168,22 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { Utils.checkAndGetK8sMasterUrl("k8s://foo://host:port") } } + + object MalformedClassObject { + class MalformedClass + } + + test("Safe getSimpleName") { + // getSimpleName on class of MalformedClass will result in error: Malformed class name + // Utils.getSimpleName works + val err = intercept[java.lang.InternalError] { + classOf[MalformedClassObject.MalformedClass].getSimpleName + } + assert(err.getMessage === "Malformed class name") + + assert(Utils.getSimpleName(classOf[MalformedClassObject.MalformedClass]) === + "UtilsSuite$MalformedClassObject$MalformedClass") + } } private class SimpleExtension diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 9552d001a079c..466135e72233a 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -11,6 +11,10 @@ cache .rat-excludes .*md derby.log +licenses/* +licenses-binary/* +LICENSE +NOTICE TAGS RELEASE control @@ -106,3 +110,4 @@ spark-warehouse structured-streaming/* kafka-source-initial-offset-version-2.1.0.bin kafka-source-initial-offset-future-version.bin +vote.tmpl diff --git a/dev/create-release/do-release-docker.sh b/dev/create-release/do-release-docker.sh new file mode 100755 index 0000000000000..fa7b73cdb40ec --- /dev/null +++ b/dev/create-release/do-release-docker.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Creates a Spark release candidate. The script will update versions, tag the branch, +# build Spark binary packages and documentation, and upload maven artifacts to a staging +# repository. There is also a dry run mode where only local builds are performed, and +# nothing is uploaded to the ASF repos. +# +# Run with "-h" for options. +# + +set -e +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + +function usage { + local NAME=$(basename $0) + cat < "$GPG_KEY_FILE" + +run_silent "Building spark-rm image with tag $IMGTAG..." "docker-build.log" \ + docker build -t "spark-rm:$IMGTAG" --build-arg UID=$UID "$SELF/spark-rm" + +# Write the release information to a file with environment variables to be used when running the +# image. +ENVFILE="$WORKDIR/env.list" +fcreate_secure "$ENVFILE" + +function cleanup { + rm -f "$ENVFILE" + rm -f "$GPG_KEY_FILE" +} + +trap cleanup EXIT + +cat > $ENVFILE <> $ENVFILE + JAVA_VOL="--volume $JAVA:/opt/spark-java" +fi + +echo "Building $RELEASE_TAG; output will be at $WORKDIR/output" +docker run -ti \ + --env-file "$ENVFILE" \ + --volume "$WORKDIR:/opt/spark-rm" \ + $JAVA_VOL \ + "spark-rm:$IMGTAG" diff --git a/dev/create-release/do-release.sh b/dev/create-release/do-release.sh new file mode 100755 index 0000000000000..f1d4f3ab5ddec --- /dev/null +++ b/dev/create-release/do-release.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + +while getopts "bn" opt; do + case $opt in + b) GIT_BRANCH=$OPTARG ;; + n) DRY_RUN=1 ;; + ?) error "Invalid option: $OPTARG" ;; + esac +done + +if [ "$RUNNING_IN_DOCKER" = "1" ]; then + # Inside docker, need to import the GPG key stored in the current directory. + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --import "$SELF/gpg.key" + + # We may need to adjust the path since JAVA_HOME may be overridden by the driver script. + if [ -n "$JAVA_HOME" ]; then + export PATH="$JAVA_HOME/bin:$PATH" + else + # JAVA_HOME for the openjdk package. + export JAVA_HOME=/usr + fi +else + # Outside docker, need to ask for information about the release. + get_release_info +fi + +function should_build { + local WHAT=$1 + [ -z "$RELEASE_STEP" ] || [ "$WHAT" = "$RELEASE_STEP" ] +} + +if should_build "tag" && [ $SKIP_TAG = 0 ]; then + run_silent "Creating release tag $RELEASE_TAG..." "tag.log" \ + "$SELF/release-tag.sh" + echo "It may take some time for the tag to be synchronized to github." + echo "Press enter when you've verified that the new tag ($RELEASE_TAG) is available." + read +else + echo "Skipping tag creation for $RELEASE_TAG." +fi + +if should_build "build"; then + run_silent "Building Spark..." "build.log" \ + "$SELF/release-build.sh" package +else + echo "Skipping build step." +fi + +if should_build "docs"; then + run_silent "Building documentation..." "docs.log" \ + "$SELF/release-build.sh" docs +else + echo "Skipping docs step." +fi + +if should_build "publish"; then + run_silent "Publishing release" "publish.log" \ + "$SELF/release-build.sh" publish-release +else + echo "Skipping publish step." +fi diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index c00b00b845401..24a62a8f4c7d3 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -17,6 +17,9 @@ # limitations under the License. # +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + function exit_with_usage { cat << EOF usage: release-build.sh @@ -87,49 +90,56 @@ NEXUS_ROOT=https://repository.apache.org/service/local/staging NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads BASE_DIR=$(pwd) -MVN="build/mvn --force" - -# Hive-specific profiles for some builds -HIVE_PROFILES="-Phive -Phive-thriftserver" -# Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" -# Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume -Psparkr" -# Scala 2.11 only profiles for some builds -SCALA_2_11_PROFILES="-Pkafka-0-8" -# Scala 2.12 only profiles for some builds -SCALA_2_12_PROFILES="-Pscala-2.12" +init_java +init_maven_sbt rm -rf spark -git clone https://git-wip-us.apache.org/repos/asf/spark.git +git clone "$ASF_REPO" cd spark git checkout $GIT_REF git_hash=`git rev-parse --short HEAD` echo "Checked out Spark git hash $git_hash" if [ -z "$SPARK_VERSION" ]; then - SPARK_VERSION=$($MVN help:evaluate -Dexpression=project.version \ - | grep -v INFO | grep -v WARNING | grep -v Download) + # Run $MVN in a separate command so that 'set -e' does the right thing. + TMP=$(mktemp) + $MVN help:evaluate -Dexpression=project.version > $TMP + SPARK_VERSION=$(cat $TMP | grep -v INFO | grep -v WARNING | grep -v Download) + rm $TMP fi -# Verify we have the right java version set -if [ -z "$JAVA_HOME" ]; then - echo "Please set JAVA_HOME." - exit 1 +# Depending on the version being built, certain extra profiles need to be activated, and +# different versions of Scala are supported. +BASE_PROFILES="-Pmesos -Pyarn" +PUBLISH_SCALA_2_10=0 +SCALA_2_10_PROFILES="-Pscala-2.10" +SCALA_2_11_PROFILES= +SCALA_2_12_PROFILES="-Pscala-2.12" + +if [[ $SPARK_VERSION > "2.3" ]]; then + BASE_PROFILES="$BASE_PROFILES -Pkubernetes -Pflume" + SCALA_2_11_PROFILES="-Pkafka-0-8" +else + PUBLISH_SCALA_2_10=1 fi -java_version=$("${JAVA_HOME}"/bin/javac -version 2>&1 | cut -d " " -f 2) +# Hive-specific profiles for some builds +HIVE_PROFILES="-Phive -Phive-thriftserver" +# Profiles for publishing snapshots and release to Maven Central +PUBLISH_PROFILES="$BASE_PROFILES $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +# Profiles for building binary releases +BASE_RELEASE_PROFILES="$BASE_PROFILES -Psparkr" if [[ ! $SPARK_VERSION < "2.2." ]]; then - if [[ $java_version < "1.8." ]]; then - echo "Java version $java_version is less than required 1.8 for 2.2+" + if [[ $JAVA_VERSION < "1.8." ]]; then + echo "Java version $JAVA_VERSION is less than required 1.8 for 2.2+" echo "Please set JAVA_HOME correctly." exit 1 fi else - if [[ $java_version > "1.7." ]]; then + if ! [[ $JAVA_VERSION =~ 1\.7\..* ]]; then if [ -z "$JAVA_7_HOME" ]; then - echo "Java version $java_version is higher than required 1.7 for pre-2.2" + echo "Java version $JAVA_VERSION is higher than required 1.7 for pre-2.2" echo "Please set JAVA_HOME correctly." exit 1 else @@ -174,8 +184,9 @@ if [[ "$1" == "package" ]]; then FLAGS=$2 ZINC_PORT=$3 BUILD_PACKAGE=$4 - cp -r spark spark-$SPARK_VERSION-bin-$NAME + echo "Building binary dist $NAME" + cp -r spark spark-$SPARK_VERSION-bin-$NAME cd spark-$SPARK_VERSION-bin-$NAME # TODO There should probably be a flag to make-distribution to allow 2.12 support @@ -244,31 +255,39 @@ if [[ "$1" == "package" ]]; then spark-$SPARK_VERSION-bin-$NAME.tgz.sha512 } - # TODO: Check exit codes of children here: - # http://stackoverflow.com/questions/1570262/shell-get-exit-code-of-background-process - # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. - make_binary_release "hadoop2.6" "-Phadoop-2.6 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3035" "withr" & - make_binary_release "hadoop2.7" "-Phadoop-2.7 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3036" "withpip" & - make_binary_release "without-hadoop" "-Phadoop-provided $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3038" & - wait - rm -rf spark-$SPARK_VERSION-bin-*/ + if ! make_binary_release "hadoop2.6" "$MVN_EXTRA_OPTS -B -Phadoop-2.6 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3035" "withr"; then + error "Failed to build hadoop2.6 package. Check logs for details." + fi + if ! is_dry_run; then + if ! make_binary_release "hadoop2.7" "$MVN_EXTRA_OPTS -B -Phadoop-2.7 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3036" "withpip"; then + error "Failed to build hadoop2.7 package. Check logs for details." + fi + if ! make_binary_release "without-hadoop" "$MVN_EXTRA_OPTS -B -Phadoop-provided $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3037"; then + error "Failed to build without-hadoop package. Check logs for details." + fi + fi - svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark - rm -rf "svn-spark/${DEST_DIR_NAME}-bin" - mkdir -p "svn-spark/${DEST_DIR_NAME}-bin" + rm -rf spark-$SPARK_VERSION-bin-*/ - echo "Copying release tarballs" - cp spark-* "svn-spark/${DEST_DIR_NAME}-bin/" - cp pyspark-* "svn-spark/${DEST_DIR_NAME}-bin/" - cp SparkR_* "svn-spark/${DEST_DIR_NAME}-bin/" - svn add "svn-spark/${DEST_DIR_NAME}-bin" + if ! is_dry_run; then + svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark + rm -rf "svn-spark/${DEST_DIR_NAME}-bin" + mkdir -p "svn-spark/${DEST_DIR_NAME}-bin" + + echo "Copying release tarballs" + cp spark-* "svn-spark/${DEST_DIR_NAME}-bin/" + cp pyspark-* "svn-spark/${DEST_DIR_NAME}-bin/" + cp SparkR_* "svn-spark/${DEST_DIR_NAME}-bin/" + svn add "svn-spark/${DEST_DIR_NAME}-bin" + + cd svn-spark + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION" + cd .. + rm -rf svn-spark + fi - cd svn-spark - svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION" - cd .. - rm -rf svn-spark exit 0 fi @@ -282,18 +301,22 @@ if [[ "$1" == "docs" ]]; then cd .. cd .. - svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark - rm -rf "svn-spark/${DEST_DIR_NAME}-docs" - mkdir -p "svn-spark/${DEST_DIR_NAME}-docs" + if ! is_dry_run; then + svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark + rm -rf "svn-spark/${DEST_DIR_NAME}-docs" + mkdir -p "svn-spark/${DEST_DIR_NAME}-docs" - echo "Copying release documentation" - cp -R "spark/docs/_site" "svn-spark/${DEST_DIR_NAME}-docs/" - svn add "svn-spark/${DEST_DIR_NAME}-docs" + echo "Copying release documentation" + cp -R "spark/docs/_site" "svn-spark/${DEST_DIR_NAME}-docs/" + svn add "svn-spark/${DEST_DIR_NAME}-docs" - cd svn-spark - svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION docs" - cd .. - rm -rf svn-spark + cd svn-spark + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION docs" + cd .. + rm -rf svn-spark + fi + + mv "spark/docs/_site" docs/ exit 0 fi @@ -341,13 +364,15 @@ if [[ "$1" == "publish-release" ]]; then # Using Nexus API documented here: # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API - echo "Creating Nexus staging repository" - repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) - staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") - echo "Created Nexus staging repository: $staged_repo_id" + if ! is_dry_run; then + echo "Creating Nexus staging repository" + repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) + staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") + echo "Created Nexus staging repository: $staged_repo_id" + fi tmp_repo=$(mktemp -d spark-repo-XXXXX) @@ -356,6 +381,12 @@ if [[ "$1" == "publish-release" ]]; then $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $SCALA_2_11_PROFILES $PUBLISH_PROFILES clean install + if ! is_dry_run && [[ $PUBLISH_SCALA_2_10 = 1 ]]; then + ./dev/change-scala-version.sh 2.10 + $MVN -DzincPort=$((ZINC_PORT + 1)) -Dmaven.repo.local=$tmp_repo -Dscala-2.10 \ + -DskipTests $PUBLISH_PROFILES $SCALA_2_10_PROFILES clean install + fi + #./dev/change-scala-version.sh 2.12 #$MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo \ # -DskipTests $SCALA_2_12_PROFILES §$PUBLISH_PROFILES clean install @@ -371,31 +402,41 @@ if [[ "$1" == "publish-release" ]]; then find . -type f |grep -v \.jar |grep -v \.pom | xargs rm echo "Creating hash and signature files" - # this must have .asc and .sha1 - it really doesn't like anything else there + # this must have .asc, .md5 and .sha1 - it really doesn't like anything else there for file in $(find . -type f) do echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ --detach-sig --armour $file; + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi sha1sum $file | cut -f1 -d' ' > $file.sha1 done - nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id - echo "Uplading files to $nexus_upload" - for file in $(find . -type f) - do - # strip leading ./ - file_short=$(echo $file | sed -e "s/\.\///") - dest_url="$nexus_upload/org/apache/spark/$file_short" - echo " Uploading $file_short" - curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url - done + if ! is_dry_run; then + nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id + echo "Uplading files to $nexus_upload" + for file in $(find . -type f) + do + # strip leading ./ + file_short=$(echo $file | sed -e "s/\.\///") + dest_url="$nexus_upload/org/apache/spark/$file_short" + echo " Uploading $file_short" + curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url + done + + echo "Closing nexus staging repository" + repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) + echo "Closed Nexus staging repository: $staged_repo_id" + fi - echo "Closing nexus staging repository" - repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) - echo "Closed Nexus staging repository: $staged_repo_id" popd rm -rf $tmp_repo cd .. diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index a05716a5f66bb..628bc0504c9c8 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -17,6 +17,9 @@ # limitations under the License. # +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + function exit_with_usage { cat << EOF usage: tag-release.sh @@ -36,6 +39,7 @@ EOF } set -e +set -o pipefail if [[ $@ == *"help"* ]]; then exit_with_usage @@ -54,8 +58,10 @@ for env in ASF_USERNAME ASF_PASSWORD RELEASE_VERSION RELEASE_TAG NEXT_VERSION GI fi done +init_java +init_maven_sbt + ASF_SPARK_REPO="git-wip-us.apache.org/repos/asf/spark.git" -MVN="build/mvn --force" rm -rf spark git clone "https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO" -b $GIT_BRANCH @@ -94,9 +100,15 @@ sed -i".tmp7" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION" git commit -a -m "Preparing development version $NEXT_VERSION" -# Push changes -git push origin $RELEASE_TAG -git push origin HEAD:$GIT_BRANCH - -cd .. -rm -rf spark +if ! is_dry_run; then + # Push changes + git push origin $RELEASE_TAG + git push origin HEAD:$GIT_BRANCH + + cd .. + rm -rf spark +else + cd .. + mv spark spark.tag + echo "Clone with version changes and tag available as spark.tag in the output directory." +fi diff --git a/dev/create-release/release-util.sh b/dev/create-release/release-util.sh new file mode 100644 index 0000000000000..7426b0d6ca08d --- /dev/null +++ b/dev/create-release/release-util.sh @@ -0,0 +1,228 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +DRY_RUN=${DRY_RUN:-0} +GPG="gpg --no-tty --batch" +ASF_REPO="https://git-wip-us.apache.org/repos/asf/spark.git" +ASF_REPO_WEBUI="https://git-wip-us.apache.org/repos/asf?p=spark.git" + +function error { + echo "$*" + exit 1 +} + +function read_config { + local PROMPT="$1" + local DEFAULT="$2" + local REPLY= + + read -p "$PROMPT [$DEFAULT]: " REPLY + local RETVAL="${REPLY:-$DEFAULT}" + if [ -z "$RETVAL" ]; then + error "$PROMPT is must be provided." + fi + echo "$RETVAL" +} + +function parse_version { + grep -e '.*' | \ + head -n 2 | tail -n 1 | cut -d'>' -f2 | cut -d '<' -f1 +} + +function run_silent { + local BANNER="$1" + local LOG_FILE="$2" + shift 2 + + echo "========================" + echo "= $BANNER" + echo "Command: $@" + echo "Log file: $LOG_FILE" + + "$@" 1>"$LOG_FILE" 2>&1 + + local EC=$? + if [ $EC != 0 ]; then + echo "Command FAILED. Check full logs for details." + tail "$LOG_FILE" + exit $EC + fi +} + +function fcreate_secure { + local FPATH="$1" + rm -f "$FPATH" + touch "$FPATH" + chmod 600 "$FPATH" +} + +function check_for_tag { + curl -s --head --fail "$ASF_REPO_WEBUI;a=commit;h=$1" >/dev/null +} + +function get_release_info { + if [ -z "$GIT_BRANCH" ]; then + # If no branch is specified, found out the latest branch from the repo. + GIT_BRANCH=$(git ls-remote --heads "$ASF_REPO" | + grep -v refs/heads/master | + awk '{print $2}' | + sort -r | + head -n 1 | + cut -d/ -f3) + fi + + export GIT_BRANCH=$(read_config "Branch" "$GIT_BRANCH") + + # Find the current version for the branch. + local VERSION=$(curl -s "$ASF_REPO_WEBUI;a=blob_plain;f=pom.xml;hb=refs/heads/$GIT_BRANCH" | + parse_version) + echo "Current branch version is $VERSION." + + if [[ ! $VERSION =~ .*-SNAPSHOT ]]; then + error "Not a SNAPSHOT version: $VERSION" + fi + + NEXT_VERSION="$VERSION" + RELEASE_VERSION="${VERSION/-SNAPSHOT/}" + SHORT_VERSION=$(echo "$VERSION" | cut -d . -f 1-2) + local REV=$(echo "$VERSION" | cut -d . -f 3) + + # Find out what rc is being prepared. + # - If the current version is "x.y.0", then this is rc1 of the "x.y.0" release. + # - If not, need to check whether the previous version has been already released or not. + # - If it has, then we're building rc1 of the current version. + # - If it has not, we're building the next RC of the previous version. + local RC_COUNT + if [ $REV != 0 ]; then + local PREV_REL_REV=$((REV - 1)) + local PREV_REL_TAG="v${SHORT_VERSION}.${PREV_REL_REV}" + if check_for_tag "$PREV_REL_TAG"; then + RC_COUNT=1 + REV=$((REV + 1)) + NEXT_VERSION="${SHORT_VERSION}.${REV}-SNAPSHOT" + else + RELEASE_VERSION="${SHORT_VERSION}.${PREV_REL_REV}" + RC_COUNT=$(git ls-remote --tags "$ASF_REPO" "v${RELEASE_VERSION}-rc*" | wc -l) + RC_COUNT=$((RC_COUNT + 1)) + fi + else + REV=$((REV + 1)) + NEXT_VERSION="${SHORT_VERSION}.${REV}-SNAPSHOT" + RC_COUNT=1 + fi + + export NEXT_VERSION + export RELEASE_VERSION=$(read_config "Release" "$RELEASE_VERSION") + + RC_COUNT=$(read_config "RC #" "$RC_COUNT") + + # Check if the RC already exists, and if re-creating the RC, skip tag creation. + RELEASE_TAG="v${RELEASE_VERSION}-rc${RC_COUNT}" + SKIP_TAG=0 + if check_for_tag "$RELEASE_TAG"; then + read -p "$RELEASE_TAG already exists. Continue anyway [y/n]? " ANSWER + if [ "$ANSWER" != "y" ]; then + error "Exiting." + fi + SKIP_TAG=1 + fi + + + export RELEASE_TAG + + GIT_REF="$RELEASE_TAG" + if is_dry_run; then + echo "This is a dry run. Please confirm the ref that will be built for testing." + GIT_REF=$(read_config "Ref" "$GIT_REF") + fi + export GIT_REF + export SPARK_PACKAGE_VERSION="$RELEASE_TAG" + + # Gather some user information. + export ASF_USERNAME=$(read_config "ASF user" "$LOGNAME") + + GIT_NAME=$(git config user.name || echo "") + export GIT_NAME=$(read_config "Full name" "$GIT_NAME") + + export GIT_EMAIL="$ASF_USERNAME@apache.org" + export GPG_KEY=$(read_config "GPG key" "$GIT_EMAIL") + + cat <&1 | cut -d " " -f 2) + export JAVA_VERSION +} + +# Initializes MVN_EXTRA_OPTS and SBT_OPTS depending on the JAVA_VERSION in use. Requires init_java. +function init_maven_sbt { + MVN="build/mvn -B" + MVN_EXTRA_OPTS= + SBT_OPTS= + if [[ $JAVA_VERSION < "1.8." ]]; then + # Needed for maven central when using Java 7. + SBT_OPTS="-Dhttps.protocols=TLSv1.1,TLSv1.2" + MVN_EXTRA_OPTS="-Dhttps.protocols=TLSv1.1,TLSv1.2" + MVN="$MVN $MVN_EXTRA_OPTS" + fi + export MVN MVN_EXTRA_OPTS SBT_OPTS +} diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 32f6cbb29f0be..ab812e1bb7c04 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -49,13 +49,16 @@ print("Install using 'sudo pip install unidecode'") sys.exit(-1) +if sys.version < '3': + input = raw_input + # Contributors list file name contributors_file_name = "contributors.txt" # Prompt the user to answer yes or no until they do so def yesOrNoPrompt(msg): - response = raw_input("%s [y/n]: " % msg) + response = input("%s [y/n]: " % msg) while response != "y" and response != "n": return yesOrNoPrompt(msg) return response == "y" diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile new file mode 100644 index 0000000000000..07ce320177f5a --- /dev/null +++ b/dev/create-release/spark-rm/Dockerfile @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Image for building Spark releases. Based on Ubuntu 16.04. +# +# Includes: +# * Java 8 +# * Ivy +# * Python/PyPandoc (2.7.12/3.5.2) +# * R-base/R-base-dev (3.3.2+) +# * Ruby 2.3 build utilities + +FROM ubuntu:16.04 + +# These arguments are just for reuse and not really meant to be customized. +ARG APT_INSTALL="apt-get install --no-install-recommends -y" + +ARG BASE_PIP_PKGS="setuptools wheel virtualenv" +ARG PIP_PKGS="pyopenssl pypandoc numpy pygments sphinx" + +# Install extra needed repos and refresh. +# - CRAN repo +# - Ruby repo (for doc generation) +# +# This is all in a single "RUN" command so that if anything changes, "apt update" is run to fetch +# the most current package versions (instead of potentially using old versions cached by docker). +RUN echo 'deb http://cran.cnr.Berkeley.edu/bin/linux/ubuntu xenial/' >> /etc/apt/sources.list && \ + gpg --keyserver keyserver.ubuntu.com --recv-key E084DAB9 && \ + gpg -a --export E084DAB9 | apt-key add - && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* && \ + apt-get clean && \ + apt-get update && \ + $APT_INSTALL software-properties-common && \ + apt-add-repository -y ppa:brightbox/ruby-ng && \ + apt-get update && \ + # Install openjdk 8. + $APT_INSTALL openjdk-8-jdk && \ + update-alternatives --set java /usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java && \ + # Install build / source control tools + $APT_INSTALL curl wget git maven ivy subversion make gcc lsof libffi-dev \ + pandoc pandoc-citeproc libssl-dev libcurl4-openssl-dev libxml2-dev && \ + ln -s -T /usr/share/java/ivy.jar /usr/share/ant/lib/ivy.jar && \ + curl -sL https://deb.nodesource.com/setup_4.x | bash && \ + $APT_INSTALL nodejs && \ + # Install needed python packages. Use pip for installing packages (for consistency). + $APT_INSTALL libpython2.7-dev libpython3-dev python-pip python3-pip && \ + pip install $BASE_PIP_PKGS && \ + pip install $PIP_PKGS && \ + cd && \ + virtualenv -p python3 p35 && \ + . p35/bin/activate && \ + pip install $BASE_PIP_PKGS && \ + pip install $PIP_PKGS && \ + # Install R packages and dependencies used when building. + # R depends on pandoc*, libssl (which are installed above). + $APT_INSTALL r-base r-base-dev && \ + $APT_INSTALL texlive-latex-base texlive texlive-fonts-extra texinfo qpdf && \ + Rscript -e "install.packages(c('curl', 'xml2', 'httr', 'devtools', 'testthat', 'knitr', 'rmarkdown', 'roxygen2', 'e1071', 'survival'), repos='http://cran.us.r-project.org/')" && \ + Rscript -e "devtools::install_github('jimhester/lintr')" && \ + # Install tools needed to build the documentation. + $APT_INSTALL ruby2.3 ruby2.3-dev && \ + gem install jekyll --no-rdoc --no-ri && \ + gem install jekyll-redirect-from && \ + gem install pygments.rb + +WORKDIR /opt/spark-rm/output + +ARG UID +RUN useradd -m -s /bin/bash -p spark-rm -u $UID spark-rm +USER spark-rm:spark-rm + +ENTRYPOINT [ "/opt/spark-rm/do-release.sh" ] diff --git a/dev/create-release/vote.tmpl b/dev/create-release/vote.tmpl new file mode 100644 index 0000000000000..2ce953c2f7ec4 --- /dev/null +++ b/dev/create-release/vote.tmpl @@ -0,0 +1,65 @@ +Please vote on releasing the following candidate as Apache Spark version {version}. + +The vote is open until {deadline} and passes if a majority +1 PMC votes are cast, with +a minimum of 3 +1 votes. + +[ ] +1 Release this package as Apache Spark {version} +[ ] -1 Do not release this package because ... + +To learn more about Apache Spark, please see http://spark.apache.org/ + +The tag to be voted on is {tag} (commit {tag_commit}): +https://github.com/apache/spark/tree/{tag} + +The release files, including signatures, digests, etc. can be found at: +https://dist.apache.org/repos/dist/dev/spark/{tag}-bin/ + +Signatures used for Spark RCs can be found in this file: +https://dist.apache.org/repos/dist/dev/spark/KEYS + +The staging repository for this release can be found at: +https://repository.apache.org/content/repositories/orgapachespark-{repo_id}/ + +The documentation corresponding to this release can be found at: +https://dist.apache.org/repos/dist/dev/spark/{tag}-docs/ + +The list of bug fixes going into {version} can be found at the following URL: +https://issues.apache.org/jira/projects/SPARK/versions/{jira_version_id} + +FAQ + +========================= +How can I help test this release? +========================= + +If you are a Spark user, you can help us test this release by taking +an existing Spark workload and running on this release candidate, then +reporting any regressions. + +If you're working in PySpark you can set up a virtual env and install +the current RC and see if anything important breaks, in the Java/Scala +you can add the staging repository to your projects resolvers and test +with the RC (make sure to clean up the artifact cache before/after so +you don't end up building with a out of date RC going forward). + +=========================================== +What should happen to JIRA tickets still targeting {version}? +=========================================== + +The current list of open tickets targeted at {version} can be found at: +{open_issues_link} + +Committers should look at those and triage. Extremely important bug +fixes, documentation, and API tweaks that impact compatibility should +be worked on immediately. Everything else please retarget to an +appropriate release. + +================== +But my bug isn't fixed? +================== + +In order to make timely releases, we will typically not hold the +release unless the bug in question is a regression from the previous +release. That being said, if there is something which is a regression +that has not been correctly targeted please ping me or a committer to +help target the issue. \ No newline at end of file diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index c3d1dd444b506..f50a0aac0aefc 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -122,7 +122,7 @@ jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.1.jar +jline-2.14.3.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -157,25 +157,25 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.3-nohive.jar -orc-mapreduce-1.4.3-nohive.jar +orc-core-1.4.4-nohive.jar +orc-mapreduce-1.4.4-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar -parquet-column-1.8.2.jar -parquet-common-1.8.2.jar -parquet-encoding-1.8.2.jar -parquet-format-2.3.1.jar -parquet-hadoop-1.8.2.jar +parquet-column-1.10.0.jar +parquet-common-1.10.0.jar +parquet-encoding-1.10.0.jar +parquet-format-2.4.0.jar +parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.8.2.jar +parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar -py4j-0.10.6.jar +py4j-0.10.7.jar pyrolite-4.13.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar @@ -190,9 +190,9 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar -xbean-asm5-shaded-4.4.jar +xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 290867035f91d..774f9dc39ce4d 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -122,7 +122,7 @@ jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.1.jar +jline-2.14.3.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -158,25 +158,25 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.3-nohive.jar -orc-mapreduce-1.4.3-nohive.jar +orc-core-1.4.4-nohive.jar +orc-mapreduce-1.4.4-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar -parquet-column-1.8.2.jar -parquet-common-1.8.2.jar -parquet-encoding-1.8.2.jar -parquet-format-2.3.1.jar -parquet-hadoop-1.8.2.jar +parquet-column-1.10.0.jar +parquet-common-1.10.0.jar +parquet-encoding-1.10.0.jar +parquet-format-2.4.0.jar +parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.8.2.jar +parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar -py4j-0.10.6.jar +py4j-0.10.7.jar pyrolite-4.13.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar @@ -191,9 +191,9 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar -xbean-asm5-shaded-4.4.jar +xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 new file mode 100644 index 0000000000000..19c05ad1e991f --- /dev/null +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -0,0 +1,221 @@ +HikariCP-java7-2.4.12.jar +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +accessors-smart-1.2.jar +activation-1.1.1.jar +aircompressor-0.8.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +antlr4-runtime-4.7.jar +aopalliance-1.0.jar +aopalliance-repackaged-2.4.0-b34.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +arrow-format-0.8.0.jar +arrow-memory-0.8.0.jar +arrow-vector-0.8.0.jar +automaton-1.11-8.jar +avro-1.7.7.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.58.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.11-0.13.2.jar +breeze_2.11-0.13.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.8.4.jar +chill_2.11-0.8.4.jar +commons-beanutils-1.9.3.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-3.0.8.jar +commons-compress-1.4.1.jar +commons-configuration2-2.1.1.jar +commons-crypto-1.0.0.jar +commons-daemon-1.0.13.jar +commons-dbcp-1.4.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.5.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-3.1.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +core-1.1.2.jar +curator-client-2.12.0.jar +curator-framework-2.12.0.jar +curator-recipes-2.12.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.12.1.1.jar +dnsjava-2.1.7.jar +ehcache-3.3.1.jar +eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar +geronimo-jcache_1.0_spec-1.0-alpha-1.jar +gson-2.2.4.jar +guava-14.0.1.jar +guice-4.0.jar +guice-servlet-4.0.jar +hadoop-annotations-3.1.0.jar +hadoop-auth-3.1.0.jar +hadoop-client-3.1.0.jar +hadoop-common-3.1.0.jar +hadoop-hdfs-client-3.1.0.jar +hadoop-mapreduce-client-common-3.1.0.jar +hadoop-mapreduce-client-core-3.1.0.jar +hadoop-mapreduce-client-jobclient-3.1.0.jar +hadoop-yarn-api-3.1.0.jar +hadoop-yarn-client-3.1.0.jar +hadoop-yarn-common-3.1.0.jar +hadoop-yarn-registry-3.1.0.jar +hadoop-yarn-server-common-3.1.0.jar +hadoop-yarn-server-web-proxy-3.1.0.jar +hk2-api-2.4.0-b34.jar +hk2-locator-2.4.0-b34.jar +hk2-utils-2.4.0-b34.jar +hppc-0.7.2.jar +htrace-core4-4.1.0-incubating.jar +httpclient-4.5.4.jar +httpcore-4.4.8.jar +ivy-2.4.0.jar +jackson-annotations-2.6.7.jar +jackson-core-2.6.7.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar +jackson-jaxrs-base-2.7.8.jar +jackson-jaxrs-json-provider-2.7.8.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar +jackson-module-paranamer-2.7.9.jar +jackson-module-scala_2.11-2.6.7.1.jar +janino-3.0.8.jar +java-xmlbuilder-1.1.jar +javassist-3.18.1-GA.jar +javax.annotation-api-1.2.jar +javax.inject-1.jar +javax.inject-2.4.0-b34.jar +javax.servlet-api-3.1.0.jar +javax.ws.rs-api-2.0.1.jar +javolution-5.5.1.jar +jaxb-api-2.2.11.jar +jcip-annotations-1.0-1.jar +jcl-over-slf4j-1.7.16.jar +jdo-api-3.0.1.jar +jersey-client-2.22.2.jar +jersey-common-2.22.2.jar +jersey-container-servlet-2.22.2.jar +jersey-container-servlet-core-2.22.2.jar +jersey-guava-2.22.2.jar +jersey-media-jaxb-2.22.2.jar +jersey-server-2.22.2.jar +jets3t-0.9.4.jar +jetty-webapp-9.3.20.v20170531.jar +jetty-xml-9.3.20.v20170531.jar +jline-2.14.3.jar +joda-time-2.9.3.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-smart-2.3.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar +jsp-api-2.1.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.16.jar +kerb-admin-1.0.1.jar +kerb-client-1.0.1.jar +kerb-common-1.0.1.jar +kerb-core-1.0.1.jar +kerb-crypto-1.0.1.jar +kerb-identity-1.0.1.jar +kerb-server-1.0.1.jar +kerb-simplekdc-1.0.1.jar +kerb-util-1.0.1.jar +kerby-asn1-1.0.1.jar +kerby-config-1.0.1.jar +kerby-pkix-1.0.1.jar +kerby-util-1.0.1.jar +kerby-xdr-1.0.1.jar +kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar +leveldbjni-all-1.8.jar +libfb303-0.9.3.jar +libthrift-0.9.3.jar +log4j-1.2.17.jar +logging-interceptor-3.8.1.jar +lz4-java-1.4.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar +mesos-1.4.0-shaded-protobuf.jar +metrics-core-3.1.5.jar +metrics-graphite-3.1.5.jar +metrics-json-3.1.5.jar +metrics-jvm-3.1.5.jar +minlog-1.3.0.jar +mssql-jdbc-6.2.1.jre7.jar +netty-3.9.9.Final.jar +netty-all-4.1.17.Final.jar +nimbus-jose-jwt-4.41.1.jar +objenesis-2.1.jar +okhttp-2.7.5.jar +okhttp-3.8.1.jar +okio-1.13.0.jar +opencsv-2.3.jar +orc-core-1.4.4-nohive.jar +orc-mapreduce-1.4.4-nohive.jar +oro-2.0.8.jar +osgi-resource-locator-1.0.1.jar +paranamer-2.8.jar +parquet-column-1.10.0.jar +parquet-common-1.10.0.jar +parquet-encoding-1.10.0.jar +parquet-format-2.4.0.jar +parquet-hadoop-1.10.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.10.0.jar +protobuf-java-2.5.0.jar +py4j-0.10.7.jar +pyrolite-4.13.jar +re2j-1.1.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar +scala-xml_2.11-1.0.5.jar +shapeless_2.11-2.3.2.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar +snappy-0.2.jar +snappy-java-1.1.7.1.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar +stax-api-1.0.1.jar +stax2-api-3.1.4.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +token-provider-1.0.1.jar +univocity-parsers-2.6.3.jar +validation-api-1.1.0.Final.jar +woodstox-core-5.0.3.jar +xbean-asm6-shaded-4.8.jar +xz-1.0.jar +zjsonpatch-0.3.0.jar +zookeeper-3.4.9.jar +zstd-jni-1.3.2-2.jar diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 84233c64caa9c..ad99ce55806af 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -211,9 +211,10 @@ mkdir -p "$DISTDIR/examples/src/main" cp -r "$SPARK_HOME/examples/src/main" "$DISTDIR/examples/src/" # Copy license and ASF files -cp "$SPARK_HOME/LICENSE" "$DISTDIR" -cp -r "$SPARK_HOME/licenses" "$DISTDIR" -cp "$SPARK_HOME/NOTICE" "$DISTDIR" +cp "$SPARK_HOME/LICENSE-binary" "$DISTDIR/LICENSE" +mkdir -p "$DISTDIR/licenses" +cp -r "$SPARK_HOME/licenses-binary" "$DISTDIR/licenses" +cp "$SPARK_HOME/NOTICE-binary" "$DISTDIR/NOTICE" if [ -e "$SPARK_HOME/CHANGES.txt" ]; then cp "$SPARK_HOME/CHANGES.txt" "$DISTDIR" diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 5ea205fbed4aa..79c7c021fe74a 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -39,6 +39,9 @@ except ImportError: JIRA_IMPORTED = False +if sys.version < '3': + input = raw_input + # Location of your Spark git development area SPARK_HOME = os.environ.get("SPARK_HOME", os.getcwd()) # Remote name which points to the Gihub site @@ -95,20 +98,21 @@ def run_cmd(cmd): def continue_maybe(prompt): - result = raw_input("\n%s (y/n): " % prompt) + result = input("\n%s (y/n): " % prompt) if result.lower() != "y": fail("Okay, exiting") def clean_up(): - print("Restoring head pointer to %s" % original_head) - run_cmd("git checkout %s" % original_head) + if 'original_head' in globals(): + print("Restoring head pointer to %s" % original_head) + run_cmd("git checkout %s" % original_head) - branches = run_cmd("git branch").replace(" ", "").split("\n") + branches = run_cmd("git branch").replace(" ", "").split("\n") - for branch in filter(lambda x: x.startswith(BRANCH_PREFIX), branches): - print("Deleting local branch %s" % branch) - run_cmd("git branch -D %s" % branch) + for branch in filter(lambda x: x.startswith(BRANCH_PREFIX), branches): + print("Deleting local branch %s" % branch) + run_cmd("git branch -D %s" % branch) # merge the requested PR and return the merge hash @@ -133,7 +137,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): '--pretty=format:%an <%ae>']).split("\n") distinct_authors = sorted(set(commit_authors), key=lambda x: commit_authors.count(x), reverse=True) - primary_author = raw_input( + primary_author = input( "Enter primary author in the format of \"name \" [%s]: " % distinct_authors[0]) if primary_author == "": @@ -183,7 +187,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): def cherry_pick(pr_num, merge_hash, default_branch): - pick_ref = raw_input("Enter a branch name [%s]: " % default_branch) + pick_ref = input("Enter a branch name [%s]: " % default_branch) if pick_ref == "": pick_ref = default_branch @@ -230,7 +234,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): asf_jira = jira.client.JIRA({'server': JIRA_API_BASE}, basic_auth=(JIRA_USERNAME, JIRA_PASSWORD)) - jira_id = raw_input("Enter a JIRA id [%s]: " % default_jira_id) + jira_id = input("Enter a JIRA id [%s]: " % default_jira_id) if jira_id == "": jira_id = default_jira_id @@ -275,7 +279,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): default_fix_versions = filter(lambda x: x != v, default_fix_versions) default_fix_versions = ",".join(default_fix_versions) - fix_versions = raw_input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions) + fix_versions = input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions) if fix_versions == "": fix_versions = default_fix_versions fix_versions = fix_versions.replace(" ", "").split(",") @@ -314,7 +318,7 @@ def choose_jira_assignee(issue, asf_jira): if author in commentors: annotations.append("Commentor") print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) - raw_assignee = raw_input( + raw_assignee = input( "Enter number of user, or userid, to assign to (blank to leave unassigned):") if raw_assignee == "": return None @@ -427,7 +431,7 @@ def main(): # Assumes branch names can be sorted lexicographically latest_branch = sorted(branch_names, reverse=True)[0] - pr_num = raw_input("Which pull request would you like to merge? (e.g. 34): ") + pr_num = input("Which pull request would you like to merge? (e.g. 34): ") pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) @@ -439,7 +443,7 @@ def main(): print("I've re-written the title as follows to match the standard format:") print("Original: %s" % pr["title"]) print("Modified: %s" % modified_title) - result = raw_input("Would you like to use the modified title? (y/n): ") + result = input("Would you like to use the modified title? (y/n): ") if result.lower() == "y": title = modified_title print("Using modified title:") @@ -490,7 +494,7 @@ def main(): merge_hash = merge_pr(pr_num, target_ref, title, body, pr_repo_desc) pick_prompt = "Would you like to pick %s into another branch?" % merge_hash - while raw_input("\n%s (y/n): " % pick_prompt).lower() == "y": + while input("\n%s (y/n): " % pick_prompt).lower() == "y": merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] if JIRA_IMPORTED: diff --git a/dev/requirements.txt b/dev/requirements.txt index 79782279f8fbd..fa833ab96b8e7 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -2,3 +2,4 @@ jira==1.0.3 PyGithub==1.26.0 Unidecode==0.04.19 pypandoc==1.3.3 +sphinx diff --git a/dev/run-pip-tests b/dev/run-pip-tests index 1321c2be4c192..7271d1014e4ae 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -89,7 +89,7 @@ for python in "${PYTHON_EXECS[@]}"; do source "$VIRTUALENV_PATH"/bin/activate fi # Upgrade pip & friends if using virutal env - if [ ! -n "USE_CONDA" ]; then + if [ ! -n "$USE_CONDA" ]; then pip install --upgrade pip pypandoc wheel numpy fi diff --git a/dev/run-tests.py b/dev/run-tests.py index 164c1e2200aa9..d9d3789ac1255 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -110,7 +110,7 @@ def determine_modules_to_test(changed_modules): ['graphx', 'examples'] >>> x = [x.name for x in determine_modules_to_test([modules.sql])] >>> x # doctest: +NORMALIZE_WHITESPACE - ['sql', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', + ['sql', 'avro', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', 'pyspark-sql', 'repl', 'sparkr', 'pyspark-mllib', 'pyspark-ml'] """ modules_to_test = set() @@ -204,7 +204,7 @@ def run_scala_style_checks(): def run_java_style_checks(): set_title_and_block("Running Java style checks", "BLOCK_JAVA_STYLE") - run_cmd([os.path.join(SPARK_HOME, "dev", "lint-java")]) + run_cmd([os.path.join(SPARK_HOME, "dev", "sbt-checkstyle")]) def run_python_style_checks(): @@ -357,7 +357,7 @@ def build_spark_unidoc_sbt(hadoop_version): exec_sbt(profiles_and_goals) -def build_spark_assembly_sbt(hadoop_version): +def build_spark_assembly_sbt(hadoop_version, checkstyle=False): # Enable all of the profiles for the build: build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["assembly/package"] @@ -366,6 +366,9 @@ def build_spark_assembly_sbt(hadoop_version): " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) + if checkstyle: + run_java_style_checks() + # Note that we skip Unidoc build only if Hadoop 2.6 is explicitly set in this SBT build. # Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the # documentation build fails on a specific machine & environment in Jenkins but it was unable @@ -570,12 +573,13 @@ def main(): or f.endswith("scalastyle-config.xml") for f in changed_files): run_scala_style_checks() + should_run_java_style_checks = False if not changed_files or any(f.endswith(".java") or f.endswith("checkstyle.xml") or f.endswith("checkstyle-suppressions.xml") for f in changed_files): - # run_java_style_checks() - pass + # Run SBT Checkstyle after the build to prevent a side-effect to the build. + should_run_java_style_checks = True if not changed_files or any(f.endswith("lint-python") or f.endswith("tox.ini") or f.endswith(".py") @@ -604,7 +608,7 @@ def main(): detect_binary_inop_with_mima(hadoop_version) # Since we did not build assembly/package before running dev/mima, we need to # do it here because the tests still rely on it; see SPARK-13294 for details. - build_spark_assembly_sbt(hadoop_version) + build_spark_assembly_sbt(hadoop_version, should_run_java_style_checks) # run the test suites run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) diff --git a/dev/sbt-checkstyle b/dev/sbt-checkstyle new file mode 100755 index 0000000000000..8821a7c0e4ccf --- /dev/null +++ b/dev/sbt-checkstyle @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# NOTE: echo "q" is needed because SBT prompts the user for input on encountering a build file +# with failure (either resolution or compilation); the "q" makes SBT quit. +ERRORS=$(echo -e "q\n" \ + | build/sbt \ + -Pkinesis-asl \ + -Pmesos \ + -Pkafka-0-8 \ + -Pkubernetes \ + -Pyarn \ + -Pflume \ + -Phive \ + -Phive-thriftserver \ + checkstyle test:checkstyle \ + | awk '{if($1~/error/)print}' \ +) + +if test ! -z "$ERRORS"; then + echo -e "Checkstyle failed at following occurrences:\n$ERRORS" + exit 1 +else + echo -e "Checkstyle checks passed." +fi + diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index dfea762db98c6..2aa355504bf29 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -170,6 +170,16 @@ def __hash__(self): ] ) +avro = Module( + name="avro", + dependencies=[sql], + source_file_regexes=[ + "external/avro", + ], + sbt_test_goals=[ + "avro/test", + ] +) sql_kafka = Module( name="sql-kafka-0-10", diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 3bf7618e1ea96..2fbd6b5e98f7f 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -34,6 +34,7 @@ MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 hadoop-2.7 + hadoop-3.1 ) # We'll switch the version to a temp. one, publish POMs using that new version, then switch back to diff --git a/dev/tox.ini b/dev/tox.ini index 583c1eaaa966b..28dad8f3b5c7c 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -16,4 +16,4 @@ [pycodestyle] ignore=E402,E731,E241,W503,E226,E722,E741,E305 max-line-length=100 -exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/* +exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/*,dist/* diff --git a/docs/building-spark.md b/docs/building-spark.md index 0236bb05849ad..c3bcd90ccc78f 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -215,19 +215,23 @@ If you are building Spark for use in a Python environment and you wish to pip in Alternatively, you can also run make-distribution with the --pip option. -## PySpark Tests with Maven +## PySpark Tests with Maven or SBT If you are building PySpark and wish to run the PySpark tests you will need to build Spark with Hive support. ./build/mvn -DskipTests clean package -Phive ./python/run-tests +If you are building PySpark with SBT and wish to run the PySpark tests, you will need to build Spark with Hive support and also build the test components: + + ./build/sbt -Phive clean package + ./build/sbt test:compile + ./python/run-tests + The run-tests script also can be limited to a specific Python version or a specific module ./python/run-tests --python-executables=python --modules=pyspark-sql -**Note:** You can also run Python tests with an sbt build, provided you build Spark with Hive support. - ## Running R Tests To run the SparkR tests you will need to install the [knitr](https://cran.r-project.org/package=knitr), [rmarkdown](https://cran.r-project.org/package=rmarkdown), [testthat](https://cran.r-project.org/package=testthat), [e1071](https://cran.r-project.org/package=e1071) and [survival](https://cran.r-project.org/package=survival) packages first: diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index ac1c336988930..18e8fe77bbdbe 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -70,7 +70,7 @@ be safely used as the direct destination of work with the normal rename-based co ### Installation With the relevant libraries on the classpath and Spark configured with valid credentials, -objects can be can be read or written by using their URLs as the path to data. +objects can be read or written by using their URLs as the path to data. For example `sparkContext.textFile("s3a://landsat-pds/scene_list.gz")` will create an RDD of the file `scene_list.gz` stored in S3, using the s3a connector. @@ -184,7 +184,8 @@ is no need for a workflow of write-then-rename to ensure that files aren't picke while they are still being written. Applications can write straight to the monitored directory. 1. Streams should only be checkpointed to a store implementing a fast and -atomic `rename()` operation Otherwise the checkpointing may be slow and potentially unreliable. +atomic `rename()` operation. +Otherwise the checkpointing may be slow and potentially unreliable. ## Further Reading diff --git a/docs/configuration.md b/docs/configuration.md index 4d4d0c58dd07d..0c7c4472be643 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -208,7 +208,7 @@ of the most common options to set are: stored on disk. This should be on a fast, local disk in your system. It can also be a comma-separated list of multiple directories on different disks. - NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone, Mesos) or + NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone), MESOS_SANDBOX (Mesos) or LOCAL_DIRS (YARN) environment variables set by the cluster manager. @@ -328,6 +328,11 @@ Apart from these, the following properties are also available, and may be useful Note that it is illegal to set Spark properties or maximum heap size (-Xmx) settings with this option. Spark properties should be set using a SparkConf object or the spark-defaults.conf file used with the spark-submit script. Maximum heap size settings can be set with spark.executor.memory. + + The following symbols, if present will be interpolated: {{APP_ID}} will be replaced by + application ID and {{EXECUTOR_ID}} will be replaced by executor ID. For example, to enable + verbose gc logging to a file named for the executor ID of the app in /tmp, pass a 'value' of: + -verbose:gc -Xloggc:/tmp/{{APP_ID}}-{{EXECUTOR_ID}}.gc @@ -905,8 +910,8 @@ Apart from these, the following properties are also available, and may be useful lz4 The codec used to compress internal data such as RDD partitions, event log, broadcast variables - and shuffle outputs. By default, Spark provides three codecs: lz4, lzf, - and snappy. You can also use fully qualified class names to specify the codec, + and shuffle outputs. By default, Spark provides four codecs: lz4, lzf, + snappy, and zstd. You can also use fully qualified class names to specify the codec, e.g. org.apache.spark.io.LZ4CompressionCodec, org.apache.spark.io.LZFCompressionCodec, @@ -1624,9 +1629,10 @@ Apart from these, the following properties are also available, and may be useful spark.blacklist.killBlacklistedExecutors false - (Experimental) If set to "true", allow Spark to automatically kill, and attempt to re-create, - executors when they are blacklisted. Note that, when an entire node is added to the blacklist, - all of the executors on that node will be killed. + (Experimental) If set to "true", allow Spark to automatically kill the executors + when they are blacklisted on fetch failure or blacklisted for the entire application, + as controlled by spark.blacklist.application.*. Note that, when an entire node is added + to the blacklist, all of the executors on that node will be killed. @@ -1753,6 +1759,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.minExecutors, spark.dynamicAllocation.maxExecutors, and spark.dynamicAllocation.initialExecutors + spark.dynamicAllocation.executorAllocationRatio @@ -1797,6 +1804,23 @@ Apart from these, the following properties are also available, and may be useful Lower bound for the number of executors if dynamic allocation is enabled. + + spark.dynamicAllocation.executorAllocationRatio + 1 + + By default, the dynamic allocation will request enough executors to maximize the + parallelism according to the number of tasks to process. While this minimizes the + latency of the job, with small tasks this setting can waste a lot of resources due to + executor allocation overhead, as some executor might not even do any work. + This setting allows to set a ratio that will be used to reduce the number of + executors w.r.t. full parallelism. + Defaults to 1.0 to give maximum parallelism. + 0.5 will divide the target number of executors by 2 + The target number of executors computed by the dynamicAllocation can still be overriden + by the spark.dynamicAllocation.minExecutors and + spark.dynamicAllocation.maxExecutors settings + + spark.dynamicAllocation.schedulerBacklogTimeout 1s diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index d660655e193eb..b3d109039da4d 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -455,11 +455,29 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat ## Naive Bayes [Naive Bayes classifiers](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple -probabilistic classifiers based on applying Bayes' theorem with strong (naive) independence -assumptions between the features. The `spark.ml` implementation currently supports both [multinomial -naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html) +probabilistic, multiclass classifiers based on applying Bayes' theorem with strong (naive) independence +assumptions between every pair of features. + +Naive Bayes can be trained very efficiently. With a single pass over the training data, +it computes the conditional probability distribution of each feature given each label. +For prediction, it applies Bayes' theorem to compute the conditional probability distribution +of each label given an observation. + +MLlib supports both [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). -More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib). + +*Input data*: +These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +Within that context, each observation is a document and each feature represents a term. +A feature's value is the frequency of the term (in multinomial Naive Bayes) or +a zero or one indicating whether the term was found in the document (in Bernoulli Naive Bayes). +Feature values must be *non-negative*. The model type is selected with an optional parameter +"multinomial" or "bernoulli" with "multinomial" as the default. +For document classification, the input feature vectors should usually be sparse vectors. +Since the training data is only used once, it is not necessary to cache it. + +[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by +setting the parameter $\lambda$ (default to $1.0$). **Examples** diff --git a/docs/ml-features.md b/docs/ml-features.md index 7aed2341584fc..ad6e718b37f1b 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1429,7 +1429,7 @@ for more details on the API. ## Imputer -The `Imputer` transformer completes missing values in a dataset, either using the mean or the +The `Imputer` estimator completes missing values in a dataset, either using the mean or the median of the columns in which the missing values are located. The input columns should be of `DoubleType` or `FloatType`. Currently `Imputer` does not support categorical features and possibly creates incorrect values for columns containing categorical features. Imputer can impute custom values diff --git a/docs/ml-statistics.md b/docs/ml-statistics.md index abfb3cab1e566..6c82b3bb94b24 100644 --- a/docs/ml-statistics.md +++ b/docs/ml-statistics.md @@ -89,4 +89,32 @@ Refer to the [`ChiSquareTest` Python docs](api/python/index.html#pyspark.ml.stat {% include_example python/ml/chi_square_test_example.py %} + + +## Summarizer + +We provide vector column summary statistics for `Dataframe` through `Summarizer`. +Available metrics are the column-wise max, min, mean, variance, and number of nonzeros, as well as the total count. + +
    +
    +The following example demonstrates using [`Summarizer`](api/scala/index.html#org.apache.spark.ml.stat.Summarizer$) +to compute the mean and variance for a vector column of the input dataframe, with and without a weight column. + +{% include_example scala/org/apache/spark/examples/ml/SummarizerExample.scala %} +
    + +
    +The following example demonstrates using [`Summarizer`](api/java/org/apache/spark/ml/stat/Summarizer.html) +to compute the mean and variance for a vector column of the input dataframe, with and without a weight column. + +{% include_example java/org/apache/spark/examples/ml/JavaSummarizerExample.java %} +
    + +
    +Refer to the [`Summarizer` Python docs](api/python/index.html#pyspark.ml.stat.Summarizer$) for details on the API. + +{% include_example python/ml/summarizer_example.py %} +
    +
    \ No newline at end of file diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index e9e1f3e280609..7149616e534aa 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -140,6 +140,12 @@ namespace as that of the driver and executor pods. For example, to mount a secre --conf spark.kubernetes.executor.secrets.spark-secret=/etc/secrets ``` +To use a secret through an environment variable use the following options to the `spark-submit` command: +``` +--conf spark.kubernetes.driver.secretKeyRef.ENV_NAME=name:key +--conf spark.kubernetes.executor.secretKeyRef.ENV_NAME=name:key +``` + ## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and @@ -264,7 +270,6 @@ future versions of the spark-kubernetes integration. Some of these include: -* PySpark * R * Dynamic Executor Scaling * Local File Dependency Management @@ -321,6 +326,13 @@ specific to Spark on Kubernetes. Container image pull policy used when pulling images within Kubernetes. + + spark.kubernetes.container.image.pullSecrets + + + Comma separated list of Kubernetes secrets used to pull images from private image registries. + + spark.kubernetes.allocation.batch.size 5 @@ -602,4 +614,83 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. + + spark.kubernetes.driver.secretKeyRef.[EnvName] + (none) + + Add as an environment variable to the driver container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example, + spark.kubernetes.driver.secretKeyRef.ENV_VAR=spark-secret:key. + + + + spark.kubernetes.executor.secretKeyRef.[EnvName] + (none) + + Add as an environment variable to the executor container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example, + spark.kubernetes.executor.secrets.ENV_VAR=spark-secret:key. + + + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.path + (none) + + Add the Kubernetes Volume named VolumeName of the VolumeType type to the driver pod on the path specified in the value. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. + + + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly + (none) + + Specify if the mounted volume is read only or not. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.readOnly=false. + + + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].options.[OptionName] + (none) + + Configure Kubernetes Volume options passed to the Kubernetes with OptionName as key having specified value, must conform with Kubernetes option format. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=spark-pvc-claim. + + + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.path + (none) + + Add the Kubernetes Volume named VolumeName of the VolumeType type to the executor pod on the path specified in the value. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. + + + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.readOnly + false + + Specify if the mounted volume is read only or not. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.readOnly=false. + + + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].options.[OptionName] + (none) + + Configure Kubernetes Volume options passed to the Kubernetes with OptionName as key having specified value. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=spark-pvc-claim. + + + + spark.kubernetes.memoryOverheadFactor + 0.1 + + This sets the Memory Overhead Factor that will allocate memory to non-JVM memory, which includes off-heap memory allocations, non-JVM tasks, and various systems processes. For JVM-based jobs this value will default to 0.10 and 0.40 for non-JVM jobs. + This is done as non-JVM tasks need more non-JVM heap space and such tasks commonly fail with "Memory Overhead Exceeded" errors. This prempts this error with a higher default. + + + + spark.kubernetes.pyspark.pythonversion + "2" + + This sets the major Python version of the docker image used to run the driver and executor containers. Can either be 2 or 3. + + diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 3c2a1501ca692..66ffb17949845 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -753,6 +753,18 @@ See the [configuration page](configuration.html) for information on Spark config spark.cores.max is reached + + spark.mesos.appJar.local.resolution.mode + host + + Provides support for the `local:///` scheme to reference the app jar resource in cluster mode. + If user uses a local resource (`local:///path/to/jar`) and the config option is not used it defaults to `host` eg. + the mesos fetcher tries to get the resource from the host's file system. + If the value is unknown it prints a warning msg in the dispatcher logs and defaults to `host`. + If the value is `container` then spark submit in the container will use the jar in the container's path: + `/path/to/jar`. + + # Troubleshooting and Debugging diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index ceda8a3ae2403..0b265b0cb1b31 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -133,9 +133,8 @@ To use a custom metrics.properties for the application master and executors, upd spark.yarn.am.waitTime 100s - In cluster mode, time for the YARN Application Master to wait for the - SparkContext to be initialized. In client mode, time for the YARN Application Master to wait - for the driver to connect to it. + Only used in cluster mode. Time for the YARN Application Master to wait for the + SparkContext to be initialized. @@ -219,9 +218,10 @@ To use a custom metrics.properties for the application master and executors, upd spark.yarn.dist.forceDownloadSchemes (none) - Comma-separated list of schemes for which files will be downloaded to the local disk prior to + Comma-separated list of schemes for which resources will be downloaded to the local disk prior to being added to YARN's distributed cache. For use in cases where the YARN service does not - support schemes that are supported by Spark, like http, https and ftp. + support schemes that are supported by Spark, like http, https and ftp, or jars required to be in the + local YARN client's classpath. Wildcard '*' is denoted to download resources for all the schemes. @@ -412,6 +412,16 @@ To use a custom metrics.properties for the application master and executors, upd name matches both the include and the exclude pattern, this file will be excluded eventually. + + spark.yarn.blacklist.executor.launch.blacklisting.enabled + false + + Flag to enable blacklisting of nodes having YARN resource allocation problems. + The error limit for blacklisting can be configured by + spark.blacklist.application.maxFailedExecutorsPerNode. + + + # Important notes @@ -425,9 +435,12 @@ To use a custom metrics.properties for the application master and executors, upd Standard Kerberos support in Spark is covered in the [Security](security.html#kerberos) page. -In YARN mode, when accessing Hadoop file systems, aside from the service hosting the user's home -directory, Spark will also automatically obtain delegation tokens for the service hosting the -staging directory of the Spark application. +In YARN mode, when accessing Hadoop filesystems, Spark will automatically obtain delegation tokens +for: + +- the filesystem hosting the staging directory of the Spark application (which is the default + filesystem if `spark.yarn.stagingDir` is not set); +- if Hadoop federation is enabled, all the federated filesystems in the configuration. If an application needs to interact with other secure Hadoop filesystems, their URIs need to be explicitly provided to Spark at launch time. This is done by listing them in the diff --git a/docs/security.md b/docs/security.md index 8c0c66fb5a285..6ef3a808e0471 100644 --- a/docs/security.md +++ b/docs/security.md @@ -177,7 +177,7 @@ ACLs can be configured for either users or groups. Configuration entries accept lists as input, meaning multiple users or groups can be given the desired privileges. This can be used if you run on a shared cluster and have a set of administrators or developers who need to monitor applications they may not have started themselves. A wildcard (`*`) added to specific ACL -means that all users will have the respective pivilege. By default, only the user submitting the +means that all users will have the respective privilege. By default, only the user submitting the application is added to the ACLs. Group membership is established by using a configurable group mapping provider. The mapper is @@ -446,6 +446,27 @@ replaced with one of the above namespaces. +Spark also supports retrieving `${ns}.keyPassword`, `${ns}.keyStorePassword` and `${ns}.trustStorePassword` from +[Hadoop Credential Providers](https://hadoop.apache.org/docs/current/hadoop-project-dist/hadoop-common/CredentialProviderAPI.html). +User could store password into credential file and make it accessible by different components, like: + +``` +hadoop credential create spark.ssl.keyPassword -value password \ + -provider jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks +``` + +To configure the location of the credential provider, set the `hadoop.security.credential.provider.path` +config option in the Hadoop configuration used by Spark, like: + +``` + + hadoop.security.credential.provider.path + jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks + +``` + +Or via SparkConf "spark.hadoop.hadoop.security.credential.provider.path=jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks". + ## Preparing the key stores Key stores can be generated by `keytool` program. The reference documentation for this tool for diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index f06e72a387df1..14d742de5655c 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -254,6 +254,18 @@ SPARK_WORKER_OPTS supports the following system properties: especially if you run jobs very frequently. + + spark.storage.cleanupFilesAfterExecutorExit + true + + Enable cleanup non-shuffle files(such as temp. shuffle blocks, cached RDD/broadcast blocks, + spill files, etc) of worker directories following executor exits. Note that this doesn't + overlap with `spark.worker.cleanup.enabled`, as this enables cleanup of non-shuffle files in + local directories of a dead executor, while `spark.worker.cleanup.enabled` enables cleanup of + all files/subdirectories of a stopped and timeout application. + This only affects Standalone mode, support of other cluster manangers can be added in the future. + + spark.worker.ui.compressedLogFileLengthCacheSize 100 diff --git a/docs/sparkr.md b/docs/sparkr.md index 7fabab5d38f16..4faad2c4c1824 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -664,6 +664,6 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma - For `summary`, option for statistics to compute has been added. Its output is changed from that from `describe`. - A warning can be raised if versions of SparkR package and the Spark JVM do not match. -## Upgrading to Spark 2.4.0 +## Upgrading to SparkR 2.3.1 and above - - The `start` parameter of `substr` method was wrongly subtracted by one, previously. In other words, the index specified by `start` parameter was considered as 0-base. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. It has been fixed so the `start` parameter of `substr` method is now 1-base, e.g., therefore to get the same result as `substr(df$a, 2, 5)`, it should be changed to `substr(df$a, 1, 4)`. + - In SparkR 2.3.0 and earlier, the `start` parameter of `substr` method was wrongly subtracted by one and considered as 0-based. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. In version 2.3.1 and later, it has been fixed so the `start` parameter of `substr` method is now 1-base. As an example, `substr(lit('abcdef'), 2, 4))` would result to `abc` in SparkR 2.3.0, and the result would be `bcd` in SparkR 2.3.1. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e8ff1470970f7..ad23dae7c6b7c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -964,7 +964,7 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession Sets the compression codec used when writing Parquet files. If either `compression` or `parquet.compression` is specified in the table-specific options/properties, the precedence would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: - none, uncompressed, snappy, gzip, lzo. + none, uncompressed, snappy, gzip, lzo, brotli, lz4, zstd. @@ -1017,7 +1017,7 @@ the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also Property NameDefaultMeaning spark.sql.orc.impl - hive + native The name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4. `hive` means the ORC library in Hive 1.2.1. @@ -1214,7 +1214,7 @@ The following options can be used to configure the version of Hive that is used 1.2.1 Version of the Hive metastore. Available - options are 0.12.0 through 2.3.2. + options are 0.12.0 through 2.3.3. @@ -1302,9 +1302,33 @@ the following case-insensitive options: dbtable - The JDBC table that should be read. Note that anything that is valid in a FROM clause of - a SQL query can be used. For example, instead of a full table you could also use a - subquery in parentheses. + The JDBC table that should be read from or written into. Note that when using it in the read + path anything that is valid in a FROM clause of a SQL query can be used. + For example, instead of a full table you could also use a subquery in parentheses. It is not + allowed to specify `dbtable` and `query` options at the same time. + + + + query + + A query that will be used to read data into Spark. The specified query will be parenthesized and used + as a subquery in the FROM clause. Spark will also assign an alias to the subquery clause. + As an example, spark will issue a query of the following form to the JDBC Source.

    + SELECT <columns> FROM (<user_specified_query>) spark_gen_alias

    + Below are couple of restrictions while using this option.
    +
      +
    1. It is not allowed to specify `dbtable` and `query` options at the same time.
    2. +
    3. It is not allowed to spcify `query` and `partitionColumn` options at the same time. When specifying + `partitionColumn` option is required, the subquery can be specified using `dbtable` option instead and + partition columns can be qualified using the subquery alias provided as part of `dbtable`.
      + Example:
      + + spark.read.format("jdbc")
      +    .option("dbtable", "(select c1, c2 from t1) as subq")
      +    .option("partitionColumn", "subq.c1"
      +    .load() +
    4. +
    @@ -1338,6 +1362,17 @@ the following case-insensitive options: + + queryTimeout + + The number of seconds the driver will wait for a Statement object to execute to the given + number of seconds. Zero means there is no limit. In the write path, this option depends on + how JDBC drivers implement the API setQueryTimeout, e.g., the h2 JDBC driver + checks the timeout of each query instead of an entire JDBC batch. + It defaults to 0. + + + fetchsize @@ -1703,7 +1738,7 @@ Using the above optimizations with Arrow will produce the same results as when A enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the DataFrame to the driver program and should be done on a small subset of the data. Not all Spark data types are currently supported and an error can be raised if a column has an unsupported type, -see [Supported SQL Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, +see [Supported SQL Types](#supported-sql-types). If an error occurs during `createDataFrame()`, Spark will fall back to create the DataFrame without Arrow. ## Pandas UDFs (a.k.a. Vectorized UDFs) @@ -1741,6 +1776,11 @@ To use `groupBy().apply()`, the user needs to define the following: * A Python function that defines the computation for each group. * A `StructType` object or a string that defines the schema of the output `DataFrame`. +The column labels of the returned `pandas.DataFrame` must either match the field names in the +defined output schema if specified as strings, or match the field data types by position if not +strings, e.g. integer indices. See [pandas.DataFrame](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html#pandas.DataFrame) +on how to label columns when constructing a `pandas.DataFrame`. + Note that all data for a group will be loaded into memory before the function is applied. This can lead to out of memory exceptons, especially if the group sizes are skewed. The configuration for [maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user @@ -1805,12 +1845,23 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. + - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. + - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. + - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, renaming a managed table to existing location is not allowed. An exception is thrown when attempting to rename a managed table to existing location. + - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in addition to the traditional cache invalidation mechanism. The non-cascading cache invalidation mechanism allows users to remove a cache without impacting its dependent caches. This new cache invalidation mechanism is used in scenarios where the data of the cache to be removed is still valid, e.g., calling unpersist() on a Dataset, or dropping a temporary view. This allows users to free up memory and keep the desired caches valid at the same time. + - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. + - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. + - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. + - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. + +## Upgrading From Spark SQL 2.3.0 to 2.3.1 and above + + - As of version 2.3.1 Arrow functionality, including `pandas_udf` and `toPandas()`/`createDataFrame()` with `spark.sql.execution.arrow.enabled` set to `True`, has been marked as experimental. These are still evolving and not currently recommended for use in production. + ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. @@ -1967,6 +2018,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error-prone. + - Un-aliased subquery's semantic has not been well defined with confusing behaviors. Since Spark 2.3, we invalidate such confusing cases, for example: `SELECT v.i from (SELECT i FROM v)`, Spark will throw an analysis exception in this case because users should not be able to use the qualifier inside a subquery. See [SPARK-20690](https://issues.apache.org/jira/browse/SPARK-20690) and [SPARK-21335](https://issues.apache.org/jira/browse/SPARK-21335) for more details. ## Upgrading From Spark SQL 2.1 to 2.2 @@ -2233,7 +2285,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently, Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.3.3. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index c30959263cdfa..118b05355c74d 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2176,6 +2176,8 @@ the input data stream (using `inputStream.repartition()`). This distributes the received batches of data across the specified number of machines in the cluster before further processing. +For direct stream, please refer to [Spark Streaming + Kafka Integration Guide](streaming-kafka-integration.html) + ### Level of Parallelism in Data Processing {:.no_toc} Cluster resources can be under-utilized if the number of parallel tasks used in any stage of the diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 602a4c70848e7..0842e8dd88672 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -926,7 +926,7 @@ event time. For a specific window starting at time `T`, the engine will maintain data to update the state until `(max event time seen by the engine - late threshold > T)`. In other words, late data within the threshold will be aggregated, but data later than the threshold will start getting dropped -(see [later]((#semantic-guarantees-of-aggregation-with-watermarking)) +(see [later](#semantic-guarantees-of-aggregation-with-watermarking) in the section for the exact guarantees). Let's understand this with an example. We can easily define watermarking on the previous example using `withWatermark()` as shown below. diff --git a/docs/tuning.md b/docs/tuning.md index 912c39879be8f..1c3bd0e8758ff 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -132,7 +132,7 @@ The best way to size the amount of memory consumption a dataset will require is into cache, and look at the "Storage" page in the web UI. The page will tell you how much memory the RDD is occupying. -To estimate the memory consumption of a particular object, use `SizeEstimator`'s `estimate` method +To estimate the memory consumption of a particular object, use `SizeEstimator`'s `estimate` method. This is useful for experimenting with different data layouts to trim memory usage, as well as determining the amount of space a broadcast variable will occupy on each executor heap. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java new file mode 100644 index 0000000000000..51865637df6f6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.clustering.PowerIterationClustering; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPowerIterationClusteringExample { + public static void main(String[] args) { + // Create a SparkSession. + SparkSession spark = SparkSession + .builder() + .appName("JavaPowerIterationClustering") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(0L, 1L, 1.0), + RowFactory.create(0L, 2L, 1.0), + RowFactory.create(1L, 2L, 1.0), + RowFactory.create(3L, 4L, 1.0), + RowFactory.create(4L, 0L, 0.1) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("src", DataTypes.LongType, false, Metadata.empty()), + new StructField("dst", DataTypes.LongType, false, Metadata.empty()), + new StructField("weight", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + PowerIterationClustering model = new PowerIterationClustering() + .setK(2) + .setMaxIter(10) + .setInitMode("degree") + .setWeightCol("weight"); + + Dataset result = model.assignClusters(df); + result.show(false); + // $example off$ + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java new file mode 100644 index 0000000000000..e9b84365d86ed --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.*; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.stat.Summarizer; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaSummarizerExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaSummarizerExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(Vectors.dense(2.0, 3.0, 5.0), 1.0), + RowFactory.create(Vectors.dense(4.0, 6.0, 7.0), 2.0) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("weight", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + Row result1 = df.select(Summarizer.metrics("mean", "variance") + .summary(new Column("features"), new Column("weight")).as("summary")) + .select("summary.mean", "summary.variance").first(); + System.out.println("with weight: mean = " + result1.getAs(0).toString() + + ", variance = " + result1.getAs(1).toString()); + + Row result2 = df.select( + Summarizer.mean(new Column("features")), + Summarizer.variance(new Column("features")) + ).first(); + System.out.println("without weight: mean = " + result2.getAs(0).toString() + + ", variance = " + result2.getAs(1).toString()); + // $example off$ + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java index b6b163fa8b2cd..748bf58f30350 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -26,7 +26,9 @@ import scala.Tuple2; +import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.serialization.StringDeserializer; import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.*; @@ -37,30 +39,33 @@ /** * Consumes messages from one or more topics in Kafka and does wordcount. - * Usage: JavaDirectKafkaWordCount + * Usage: JavaDirectKafkaWordCount * is a list of one or more Kafka brokers + * is a consumer group name to consume from topics * is a list of one or more kafka topics to consume from * * Example: * $ bin/run-example streaming.JavaDirectKafkaWordCount broker1-host:port,broker2-host:port \ - * topic1,topic2 + * consumer-group topic1,topic2 */ public final class JavaDirectKafkaWordCount { private static final Pattern SPACE = Pattern.compile(" "); public static void main(String[] args) throws Exception { - if (args.length < 2) { - System.err.println("Usage: JavaDirectKafkaWordCount \n" + - " is a list of one or more Kafka brokers\n" + - " is a list of one or more kafka topics to consume from\n\n"); + if (args.length < 3) { + System.err.println("Usage: JavaDirectKafkaWordCount \n" + + " is a list of one or more Kafka brokers\n" + + " is a consumer group name to consume from topics\n" + + " is a list of one or more kafka topics to consume from\n\n"); System.exit(1); } StreamingExamples.setStreamingLogLevels(); String brokers = args[0]; - String topics = args[1]; + String groupId = args[1]; + String topics = args[2]; // Create context with a 2 seconds batch interval SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount"); @@ -68,7 +73,10 @@ public static void main(String[] args) throws Exception { Set topicsSet = new HashSet<>(Arrays.asList(topics.split(","))); Map kafkaParams = new HashMap<>(); - kafkaParams.put("metadata.broker.list", brokers); + kafkaParams.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokers); + kafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, groupId); + kafkaParams.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + kafkaParams.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); // Create direct kafka stream with brokers and topics JavaInputDStream> messages = KafkaUtils.createDirectStream( diff --git a/examples/src/main/python/ml/summarizer_example.py b/examples/src/main/python/ml/summarizer_example.py new file mode 100644 index 0000000000000..8835f189a1ad4 --- /dev/null +++ b/examples/src/main/python/ml/summarizer_example.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example for summarizer. +Run with: + bin/spark-submit examples/src/main/python/ml/summarizer_example.py +""" +from __future__ import print_function + +from pyspark.sql import SparkSession +# $example on$ +from pyspark.ml.stat import Summarizer +from pyspark.sql import Row +from pyspark.ml.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("SummarizerExample") \ + .getOrCreate() + sc = spark.sparkContext + + # $example on$ + df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)), + Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF() + + # create summarizer for multiple metrics "mean" and "count" + summarizer = Summarizer.metrics("mean", "count") + + # compute statistics for multiple metrics with weight + df.select(summarizer.summary(df.features, df.weight)).show(truncate=False) + + # compute statistics for multiple metrics without weight + df.select(summarizer.summary(df.features)).show(truncate=False) + + # compute statistics for single metric "mean" with weight + df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False) + + # compute statistics for single metric "mean" without weight + df.select(Summarizer.mean(df.features)).show(truncate=False) + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/py_container_checks.py b/examples/src/main/python/py_container_checks.py new file mode 100644 index 0000000000000..f6b3be2806c82 --- /dev/null +++ b/examples/src/main/python/py_container_checks.py @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import sys + + +def version_check(python_env, major_python_version): + """ + These are various tests to test the Python container image. + This file will be distributed via --py-files in the e2e tests. + """ + env_version = os.environ.get('PYSPARK_PYTHON') + print("Python runtime version check is: " + + str(sys.version_info[0] == major_python_version)) + + print("Python environment version check is: " + + str(env_version == python_env)) diff --git a/examples/src/main/python/pyfiles.py b/examples/src/main/python/pyfiles.py new file mode 100644 index 0000000000000..4193654b49a12 --- /dev/null +++ b/examples/src/main/python/pyfiles.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark.sql import SparkSession + + +if __name__ == "__main__": + """ + Usage: pyfiles [major_python_version] + """ + spark = SparkSession \ + .builder \ + .appName("PyFilesTest") \ + .getOrCreate() + + from py_container_checks import version_check + # Begin of Python container checks + version_check(sys.argv[1], 2 if sys.argv[1] == "python" else 3) + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala new file mode 100644 index 0000000000000..ca8f7affb14e8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.clustering.PowerIterationClustering +// $example off$ +import org.apache.spark.sql.SparkSession + +object PowerIterationClusteringExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + + // $example on$ + val dataset = spark.createDataFrame(Seq( + (0L, 1L, 1.0), + (0L, 2L, 1.0), + (1L, 2L, 1.0), + (3L, 4L, 1.0), + (4L, 0L, 0.1) + )).toDF("src", "dst", "weight") + + val model = new PowerIterationClustering(). + setK(2). + setMaxIter(20). + setInitMode("degree"). + setWeightCol("weight") + + val prediction = model.assignClusters(dataset).select("id", "cluster") + + // Shows the cluster assignment + prediction.show(false) + // $example off$ + + spark.stop() + } + } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala new file mode 100644 index 0000000000000..2f54d1d81bc48 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.stat.Summarizer +// $example off$ +import org.apache.spark.sql.SparkSession + +object SummarizerExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("SummarizerExample") + .getOrCreate() + + import spark.implicits._ + import Summarizer._ + + // $example on$ + val data = Seq( + (Vectors.dense(2.0, 3.0, 5.0), 1.0), + (Vectors.dense(4.0, 6.0, 7.0), 2.0) + ) + + val df = data.toDF("features", "weight") + + val (meanVal, varianceVal) = df.select(metrics("mean", "variance") + .summary($"features", $"weight").as("summary")) + .select("summary.mean", "summary.variance") + .as[(Vector, Vector)].first() + + println(s"with weight: mean = ${meanVal}, variance = ${varianceVal}") + + val (meanVal2, varianceVal2) = df.select(mean($"features"), variance($"features")) + .as[(Vector, Vector)].first() + + println(s"without weight: mean = ${meanVal2}, sum = ${varianceVal2}") + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/external/avro/pom.xml b/external/avro/pom.xml new file mode 100644 index 0000000000000..42e865bc38824 --- /dev/null +++ b/external/avro/pom.xml @@ -0,0 +1,73 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.4.0-SNAPSHOT + ../../pom.xml + + + spark-sql-avro_2.11 + + avro + + jar + Spark Avro + http://spark.apache.org/ + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-tags_${scala.binary.version} + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000..95835f0d4ca49 --- /dev/null +++ b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.avro.AvroFileFormat diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala new file mode 100644 index 0000000000000..b31149a2c74c2 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.Schema.Type._ +import org.apache.avro.generic._ +import org.apache.avro.util.Utf8 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A deserializer to deserialize data in avro format to data in catalyst format. + */ +class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { + private val converter: Any => Any = rootCatalystType match { + // A shortcut for empty schema. + case st: StructType if st.isEmpty => + (data: Any) => InternalRow.empty + + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + val fieldUpdater = new RowUpdater(resultRow) + val writer = getRecordWriter(rootAvroType, st, Nil) + (data: Any) => { + val record = data.asInstanceOf[GenericRecord] + writer(fieldUpdater, record) + resultRow + } + + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val fieldUpdater = new RowUpdater(tmpRow) + val writer = newWriter(rootAvroType, rootCatalystType, Nil) + (data: Any) => { + writer(fieldUpdater, 0, data) + tmpRow.get(0, rootCatalystType) + } + } + + def deserialize(data: Any): Any = converter(data) + + /** + * Creates a writer to write avro values to Catalyst values at the given ordinal with the given + * updater. + */ + private def newWriter( + avroType: Schema, + catalystType: DataType, + path: List[String]): (CatalystDataUpdater, Int, Any) => Unit = + (avroType.getType, catalystType) match { + case (NULL, NullType) => (updater, ordinal, _) => + updater.setNullAt(ordinal) + + // TODO: we can avoid boxing if future version of avro provide primitive accessors. + case (BOOLEAN, BooleanType) => (updater, ordinal, value) => + updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) + + case (INT, IntegerType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (LONG, LongType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + + case (LONG, TimestampType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long] * 1000) + + case (LONG, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, (value.asInstanceOf[Long] / DateTimeUtils.MILLIS_PER_DAY).toInt) + + case (FLOAT, FloatType) => (updater, ordinal, value) => + updater.setFloat(ordinal, value.asInstanceOf[Float]) + + case (DOUBLE, DoubleType) => (updater, ordinal, value) => + updater.setDouble(ordinal, value.asInstanceOf[Double]) + + case (STRING, StringType) => (updater, ordinal, value) => + val str = value match { + case s: String => UTF8String.fromString(s) + case s: Utf8 => + val bytes = new Array[Byte](s.getByteLength) + System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength) + UTF8String.fromBytes(bytes) + } + updater.set(ordinal, str) + + case (ENUM, StringType) => (updater, ordinal, value) => + updater.set(ordinal, UTF8String.fromString(value.toString)) + + case (FIXED, BinaryType) => (updater, ordinal, value) => + updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone()) + + case (BYTES, BinaryType) => (updater, ordinal, value) => + val bytes = value match { + case b: ByteBuffer => + val bytes = new Array[Byte](b.remaining) + b.get(bytes) + bytes + case b: Array[Byte] => b + case other => throw new RuntimeException(s"$other is not a valid avro binary.") + + } + updater.set(ordinal, bytes) + + case (RECORD, st: StructType) => + val writeRecord = getRecordWriter(avroType, st, path) + (updater, ordinal, value) => + val row = new SpecificInternalRow(st) + writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord]) + updater.set(ordinal, row) + + case (ARRAY, ArrayType(elementType, containsNull)) => + val elementWriter = newWriter(avroType.getElementType, elementType, path) + (updater, ordinal, value) => + val array = value.asInstanceOf[GenericData.Array[Any]] + val len = array.size() + val result = createArrayData(elementType, len) + val elementUpdater = new ArrayDataUpdater(result) + + var i = 0 + while (i < len) { + val element = array.get(i) + if (element == null) { + if (!containsNull) { + throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + elementUpdater.setNullAt(i) + } + } else { + elementWriter(elementUpdater, i, element) + } + i += 1 + } + + updater.set(ordinal, result) + + case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == StringType => + val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType, path) + val valueWriter = newWriter(avroType.getValueType, valueType, path) + (updater, ordinal, value) => + val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]] + val keyArray = createArrayData(keyType, map.size()) + val keyUpdater = new ArrayDataUpdater(keyArray) + val valueArray = createArrayData(valueType, map.size()) + val valueUpdater = new ArrayDataUpdater(valueArray) + val iter = map.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + assert(entry.getKey != null) + keyWriter(keyUpdater, i, entry.getKey) + if (entry.getValue == null) { + if (!valueContainsNull) { + throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + valueUpdater.setNullAt(i) + } + } else { + valueWriter(valueUpdater, i, entry.getValue) + } + i += 1 + } + + updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) + + case (UNION, _) => + val allTypes = avroType.getTypes.asScala + val nonNullTypes = allTypes.filter(_.getType != NULL) + if (nonNullTypes.nonEmpty) { + if (nonNullTypes.length == 1) { + newWriter(nonNullTypes.head, catalystType, path) + } else { + nonNullTypes.map(_.getType) match { + case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == LongType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case l: java.lang.Long => updater.setLong(ordinal, l) + case i: java.lang.Integer => updater.setLong(ordinal, i.longValue()) + } + + case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && catalystType == DoubleType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case d: java.lang.Double => updater.setDouble(ordinal, d) + case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue()) + } + + case _ => + catalystType match { + case st: StructType if st.length == nonNullTypes.size => + val fieldWriters = nonNullTypes.zip(st.fields).map { + case (schema, field) => newWriter(schema, field.dataType, path :+ field.name) + }.toArray + (updater, ordinal, value) => { + val row = new SpecificInternalRow(st) + val fieldUpdater = new RowUpdater(row) + val i = GenericData.get().resolveUnion(avroType, value) + fieldWriters(i)(fieldUpdater, i, value) + updater.set(ordinal, row) + } + + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Avro to catalyst because schema at path " + + s"${path.mkString(".")} is not compatible " + + s"(avroType = $avroType, sqlType = $catalystType).\n" + + s"Source Avro schema: $rootAvroType.\n" + + s"Target Catalyst type: $rootCatalystType") + } + } + } + } else { + (updater, ordinal, value) => updater.setNullAt(ordinal) + } + + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Avro to catalyst because schema at path ${path.mkString(".")} " + + s"is not compatible (avroType = $avroType, sqlType = $catalystType).\n" + + s"Source Avro schema: $rootAvroType.\n" + + s"Target Catalyst type: $rootCatalystType") + } + + private def getRecordWriter( + avroType: Schema, + sqlType: StructType, + path: List[String]): (CatalystDataUpdater, GenericRecord) => Unit = { + val validFieldIndexes = ArrayBuffer.empty[Int] + val fieldWriters = ArrayBuffer.empty[(CatalystDataUpdater, Any) => Unit] + + val length = sqlType.length + var i = 0 + while (i < length) { + val sqlField = sqlType.fields(i) + val avroField = avroType.getField(sqlField.name) + if (avroField != null) { + validFieldIndexes += avroField.pos() + + val baseWriter = newWriter(avroField.schema(), sqlField.dataType, path :+ sqlField.name) + val ordinal = i + val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { + if (value == null) { + fieldUpdater.setNullAt(ordinal) + } else { + baseWriter(fieldUpdater, ordinal, value) + } + } + fieldWriters += fieldWriter + } else if (!sqlField.nullable) { + throw new IncompatibleSchemaException( + s""" + |Cannot find non-nullable field ${path.mkString(".")}.${sqlField.name} in Avro schema. + |Source Avro schema: $rootAvroType. + |Target Catalyst type: $rootCatalystType. + """.stripMargin) + } + i += 1 + } + + (fieldUpdater, record) => { + var i = 0 + while (i < validFieldIndexes.length) { + fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i))) + i += 1 + } + } + } + + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) + } + + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. + */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + } + + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + } + + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala new file mode 100755 index 0000000000000..fb93033bb15d4 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io._ +import java.net.URI +import java.util.zip.Deflater + +import scala.util.control.NonFatal + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.avro.Schema +import org.apache.avro.file.{DataFileConstants, DataFileReader} +import org.apache.avro.generic.{GenericDatumReader, GenericRecord} +import org.apache.avro.mapred.{AvroOutputFormat, FsInput} +import org.apache.avro.mapreduce.AvroJob +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce.Job +import org.slf4j.LoggerFactory + +import org.apache.spark.TaskContext +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile} +import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.types.StructType + +private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { + private val log = LoggerFactory.getLogger(getClass) + + override def equals(other: Any): Boolean = other match { + case _: AvroFileFormat => true + case _ => false + } + + // Dummy hashCode() to appease ScalaStyle. + override def hashCode(): Int = super.hashCode() + + override def inferSchema( + spark: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + val conf = spark.sparkContext.hadoopConfiguration + + // Schema evolution is not supported yet. Here we only pick a single random sample file to + // figure out the schema of the whole dataset. + val sampleFile = + if (conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, true)) { + files.find(_.getPath.getName.endsWith(".avro")).getOrElse { + throw new FileNotFoundException( + "No Avro files found. Hadoop option \"avro.mapred.ignore.inputs.without.extension\" " + + " is set to true. Do all input files have \".avro\" extension?" + ) + } + } else { + files.headOption.getOrElse { + throw new FileNotFoundException("No Avro files found.") + } + } + + // User can specify an optional avro json schema. + val avroSchema = options.get(AvroFileFormat.AvroSchema) + .map(new Schema.Parser().parse) + .getOrElse { + val in = new FsInput(sampleFile.getPath, conf) + try { + val reader = DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]()) + try { + reader.getSchema + } finally { + reader.close() + } + } finally { + in.close() + } + } + + SchemaConverters.toSqlType(avroSchema).dataType match { + case t: StructType => Some(t) + case _ => throw new RuntimeException( + s"""Avro schema cannot be converted to a Spark SQL StructType: + | + |${avroSchema.toString(true)} + |""".stripMargin) + } + } + + override def shortName(): String = "avro" + + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = true + + override def prepareWrite( + spark: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val recordName = options.getOrElse("recordName", "topLevelRecord") + val recordNamespace = options.getOrElse("recordNamespace", "") + val outputAvroSchema = SchemaConverters.toAvroType( + dataSchema, nullable = false, recordName, recordNamespace) + + AvroJob.setOutputKeySchema(job, outputAvroSchema) + val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" + val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" + val COMPRESS_KEY = "mapred.output.compress" + + spark.conf.get(AVRO_COMPRESSION_CODEC, "snappy") match { + case "uncompressed" => + log.info("writing uncompressed Avro records") + job.getConfiguration.setBoolean(COMPRESS_KEY, false) + + case "snappy" => + log.info("compressing Avro output using Snappy") + job.getConfiguration.setBoolean(COMPRESS_KEY, true) + job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.SNAPPY_CODEC) + + case "deflate" => + val deflateLevel = spark.conf.get( + AVRO_DEFLATE_LEVEL, Deflater.DEFAULT_COMPRESSION.toString).toInt + log.info(s"compressing Avro output using deflate (level=$deflateLevel)") + job.getConfiguration.setBoolean(COMPRESS_KEY, true) + job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.DEFLATE_CODEC) + job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel) + + case unknown: String => + log.error(s"unsupported compression codec $unknown") + } + + new AvroOutputWriterFactory(dataSchema, new SerializableSchema(outputAvroSchema)) + } + + override def buildReader( + spark: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + + val broadcastedConf = + spark.sparkContext.broadcast(new AvroFileFormat.SerializableConfiguration(hadoopConf)) + + (file: PartitionedFile) => { + val log = LoggerFactory.getLogger(classOf[AvroFileFormat]) + val conf = broadcastedConf.value.value + val userProvidedSchema = options.get(AvroFileFormat.AvroSchema).map(new Schema.Parser().parse) + + // TODO Removes this check once `FileFormat` gets a general file filtering interface method. + // Doing input file filtering is improper because we may generate empty tasks that process no + // input files but stress the scheduler. We should probably add a more general input file + // filtering mechanism for `FileFormat` data sources. See SPARK-16317. + if ( + conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, true) && + !file.filePath.endsWith(".avro") + ) { + Iterator.empty + } else { + val reader = { + val in = new FsInput(new Path(new URI(file.filePath)), conf) + try { + val datumReader = userProvidedSchema match { + case Some(userSchema) => new GenericDatumReader[GenericRecord](userSchema) + case _ => new GenericDatumReader[GenericRecord]() + } + DataFileReader.openReader(in, datumReader) + } catch { + case NonFatal(e) => + log.error("Exception while opening DataFileReader", e) + in.close() + throw e + } + } + + // Ensure that the reader is closed even if the task fails or doesn't consume the entire + // iterator of records. + Option(TaskContext.get()).foreach { taskContext => + taskContext.addTaskCompletionListener { _ => + reader.close() + } + } + + reader.sync(file.start) + val stop = file.start + file.length + + val deserializer = + new AvroDeserializer(userProvidedSchema.getOrElse(reader.getSchema), requiredSchema) + + new Iterator[InternalRow] { + private[this] var completed = false + + override def hasNext: Boolean = { + if (completed) { + false + } else { + val r = reader.hasNext && !reader.pastSync(stop) + if (!r) { + reader.close() + completed = true + } + r + } + } + + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException("next on empty iterator") + } + val record = reader.next() + deserializer.deserialize(record).asInstanceOf[InternalRow] + } + } + } + } + } +} + +private[avro] object AvroFileFormat { + val IgnoreFilesWithoutExtensionProperty = "avro.mapred.ignore.inputs.without.extension" + + val AvroSchema = "avroSchema" + + class SerializableConfiguration(@transient var value: Configuration) + extends Serializable with KryoSerializable { + @transient private[avro] lazy val log = LoggerFactory.getLogger(getClass) + + private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException { + out.defaultWriteObject() + value.write(out) + } + + private def readObject(in: ObjectInputStream): Unit = tryOrIOException { + value = new Configuration(false) + value.readFields(in) + } + + private def tryOrIOException[T](block: => T): T = { + try { + block + } catch { + case e: IOException => + log.error("Exception encountered", e) + throw e + case NonFatal(e) => + log.error("Exception encountered", e) + throw new IOException(e) + } + } + + def write(kryo: Kryo, out: Output): Unit = { + val dos = new DataOutputStream(out) + value.write(dos) + dos.flush() + } + + def read(kryo: Kryo, in: Input): Unit = { + value = new Configuration(false) + value.readFields(new DataInputStream(in)) + } + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala new file mode 100644 index 0000000000000..06507115f5ed8 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io.{IOException, OutputStream} + +import org.apache.avro.Schema +import org.apache.avro.generic.GenericRecord +import org.apache.avro.mapred.AvroKey +import org.apache.avro.mapreduce.AvroKeyOutputFormat +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.OutputWriter +import org.apache.spark.sql.types._ + +// NOTE: This class is instantiated and used on executor side only, no need to be serializable. +private[avro] class AvroOutputWriter( + path: String, + context: TaskAttemptContext, + schema: StructType, + avroSchema: Schema) extends OutputWriter { + + // The input rows will never be null. + private lazy val serializer = new AvroSerializer(schema, avroSchema, nullable = false) + + /** + * Overrides the couple of methods responsible for generating the output streams / files so + * that the data can be correctly partitioned + */ + private val recordWriter: RecordWriter[AvroKey[GenericRecord], NullWritable] = + new AvroKeyOutputFormat[GenericRecord]() { + + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + new Path(path) + } + + @throws(classOf[IOException]) + override def getAvroFileOutputStream(c: TaskAttemptContext): OutputStream = { + val path = getDefaultWorkFile(context, ".avro") + path.getFileSystem(context.getConfiguration).create(path) + } + + }.getRecordWriter(context) + + override def write(row: InternalRow): Unit = { + val key = new AvroKey(serializer.serialize(row).asInstanceOf[GenericRecord]) + recordWriter.write(key, NullWritable.get()) + } + + override def close(): Unit = recordWriter.close(context) +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala new file mode 100644 index 0000000000000..18a6d93951408 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.hadoop.mapreduce.TaskAttemptContext + +import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.types.StructType + +private[avro] class AvroOutputWriterFactory( + schema: StructType, + avroSchema: SerializableSchema) extends OutputWriterFactory { + + override def getFileExtension(context: TaskAttemptContext): String = ".avro" + + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new AvroOutputWriter(path, context, schema, avroSchema.value) + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala new file mode 100644 index 0000000000000..2b4c5813a535b --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ + +import org.apache.avro.Schema +import org.apache.avro.Schema.Type.NULL +import org.apache.avro.generic.GenericData.Record +import org.apache.avro.util.Utf8 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ + +/** + * A serializer to serialize data in catalyst format to data in avro format. + */ +class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) { + + def serialize(catalystData: Any): Any = { + converter.apply(catalystData) + } + + private val converter: Any => Any = { + val actualAvroType = resolveNullableType(rootAvroType, nullable) + val baseConverter = rootCatalystType match { + case st: StructType => + newStructConverter(st, actualAvroType).asInstanceOf[Any => Any] + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val converter = newConverter(rootCatalystType, actualAvroType) + (data: Any) => + tmpRow.update(0, data) + converter.apply(tmpRow, 0) + } + if (nullable) { + (data: Any) => + if (data == null) { + null + } else { + baseConverter.apply(data) + } + } else { + baseConverter + } + } + + private type Converter = (SpecializedGetters, Int) => Any + + private def newConverter(catalystType: DataType, avroType: Schema): Converter = { + catalystType match { + case NullType => + (getter, ordinal) => null + case BooleanType => + (getter, ordinal) => getter.getBoolean(ordinal) + case ByteType => + (getter, ordinal) => getter.getByte(ordinal).toInt + case ShortType => + (getter, ordinal) => getter.getShort(ordinal).toInt + case IntegerType => + (getter, ordinal) => getter.getInt(ordinal) + case LongType => + (getter, ordinal) => getter.getLong(ordinal) + case FloatType => + (getter, ordinal) => getter.getFloat(ordinal) + case DoubleType => + (getter, ordinal) => getter.getDouble(ordinal) + case d: DecimalType => + (getter, ordinal) => getter.getDecimal(ordinal, d.precision, d.scale).toString + case StringType => + (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) + case BinaryType => + (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) + case DateType => + (getter, ordinal) => getter.getInt(ordinal) * DateTimeUtils.MILLIS_PER_DAY + case TimestampType => + (getter, ordinal) => getter.getLong(ordinal) / 1000 + + case ArrayType(et, containsNull) => + val elementConverter = newConverter( + et, resolveNullableType(avroType.getElementType, containsNull)) + (getter, ordinal) => { + val arrayData = getter.getArray(ordinal) + val result = new java.util.ArrayList[Any] + var i = 0 + while (i < arrayData.numElements()) { + if (arrayData.isNullAt(i)) { + result.add(null) + } else { + result.add(elementConverter(arrayData, i)) + } + i += 1 + } + result + } + + case st: StructType => + val structConverter = newStructConverter(st, avroType) + val numFields = st.length + (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields)) + + case MapType(kt, vt, valueContainsNull) if kt == StringType => + val valueConverter = newConverter( + vt, resolveNullableType(avroType.getValueType, valueContainsNull)) + (getter, ordinal) => + val mapData = getter.getMap(ordinal) + val result = new java.util.HashMap[String, Any](mapData.numElements()) + val keyArray = mapData.keyArray() + val valueArray = mapData.valueArray() + var i = 0 + while (i < mapData.numElements()) { + val key = keyArray.getUTF8String(i).toString + if (valueArray.isNullAt(i)) { + result.put(key, null) + } else { + result.put(key, valueConverter(valueArray, i)) + } + i += 1 + } + result + + case other => + throw new IncompatibleSchemaException(s"Unexpected type: $other") + } + } + + private def newStructConverter( + catalystStruct: StructType, avroStruct: Schema): InternalRow => Record = { + val avroFields = avroStruct.getFields + assert(avroFields.size() == catalystStruct.length) + val fieldConverters = catalystStruct.zip(avroFields.asScala).map { + case (f1, f2) => newConverter(f1.dataType, resolveNullableType(f2.schema(), f1.nullable)) + } + val numFields = catalystStruct.length + (row: InternalRow) => + val result = new Record(avroStruct) + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + result.put(i, null) + } else { + result.put(i, fieldConverters(i).apply(row, i)) + } + i += 1 + } + result + } + + private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = { + if (nullable) { + // avro uses union to represent nullable type. + val fields = avroType.getTypes.asScala + assert(fields.length == 2) + val actualType = fields.filter(_.getType != NULL) + assert(actualType.length == 1) + actualType.head + } else { + avroType + } + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala new file mode 100644 index 0000000000000..87fae63aeff2b --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import scala.collection.JavaConverters._ + +import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.Schema.Type._ + +import org.apache.spark.sql.types._ + +/** + * This object contains method that are used to convert sparkSQL schemas to avro schemas and vice + * versa. + */ +object SchemaConverters { + case class SchemaType(dataType: DataType, nullable: Boolean) + + /** + * This function takes an avro schema and returns a sql schema. + */ + def toSqlType(avroSchema: Schema): SchemaType = { + avroSchema.getType match { + case INT => SchemaType(IntegerType, nullable = false) + case STRING => SchemaType(StringType, nullable = false) + case BOOLEAN => SchemaType(BooleanType, nullable = false) + case BYTES => SchemaType(BinaryType, nullable = false) + case DOUBLE => SchemaType(DoubleType, nullable = false) + case FLOAT => SchemaType(FloatType, nullable = false) + case LONG => SchemaType(LongType, nullable = false) + case FIXED => SchemaType(BinaryType, nullable = false) + case ENUM => SchemaType(StringType, nullable = false) + + case RECORD => + val fields = avroSchema.getFields.asScala.map { f => + val schemaType = toSqlType(f.schema()) + StructField(f.name, schemaType.dataType, schemaType.nullable) + } + + SchemaType(StructType(fields), nullable = false) + + case ARRAY => + val schemaType = toSqlType(avroSchema.getElementType) + SchemaType( + ArrayType(schemaType.dataType, containsNull = schemaType.nullable), + nullable = false) + + case MAP => + val schemaType = toSqlType(avroSchema.getValueType) + SchemaType( + MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), + nullable = false) + + case UNION => + if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { + // In case of a union with null, eliminate it and make a recursive call + val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) + if (remainingUnionTypes.size == 1) { + toSqlType(remainingUnionTypes.head).copy(nullable = true) + } else { + toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true) + } + } else avroSchema.getTypes.asScala.map(_.getType) match { + case Seq(t1) => + toSqlType(avroSchema.getTypes.get(0)) + case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => + SchemaType(LongType, nullable = false) + case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => + SchemaType(DoubleType, nullable = false) + case _ => + // Convert complex unions to struct types where field names are member0, member1, etc. + // This is consistent with the behavior when converting between Avro and Parquet. + val fields = avroSchema.getTypes.asScala.zipWithIndex.map { + case (s, i) => + val schemaType = toSqlType(s) + // All fields are nullable because only one of them is set at a time + StructField(s"member$i", schemaType.dataType, nullable = true) + } + + SchemaType(StructType(fields), nullable = false) + } + + case other => throw new IncompatibleSchemaException(s"Unsupported type $other") + } + } + + def toAvroType( + catalystType: DataType, + nullable: Boolean = false, + recordName: String = "topLevelRecord", + prevNameSpace: String = ""): Schema = { + val builder = if (nullable) { + SchemaBuilder.builder().nullable() + } else { + SchemaBuilder.builder() + } + catalystType match { + case BooleanType => builder.booleanType() + case ByteType | ShortType | IntegerType => builder.intType() + case LongType => builder.longType() + case DateType => builder.longType() + case TimestampType => builder.longType() + case FloatType => builder.floatType() + case DoubleType => builder.doubleType() + case _: DecimalType | StringType => builder.stringType() + case BinaryType => builder.bytesType() + case ArrayType(et, containsNull) => + builder.array().items(toAvroType(et, containsNull, recordName, prevNameSpace)) + case MapType(StringType, vt, valueContainsNull) => + builder.map().values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace)) + case st: StructType => + val nameSpace = s"$prevNameSpace.$recordName" + val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields() + st.foreach { f => + val fieldAvroType = toAvroType(f.dataType, f.nullable, f.name, nameSpace) + fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault() + } + fieldsAssembler.endRecord() + + // This should never happen. + case other => throw new IncompatibleSchemaException(s"Unexpected type $other.") + } + } +} + +class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala new file mode 100644 index 0000000000000..ec0ddc778c8f6 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io._ + +import scala.util.control.NonFatal + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.avro.Schema +import org.slf4j.LoggerFactory + +class SerializableSchema(@transient var value: Schema) + extends Serializable with KryoSerializable { + + @transient private[avro] lazy val log = LoggerFactory.getLogger(getClass) + + private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException { + out.defaultWriteObject() + out.writeUTF(value.toString()) + out.flush() + } + + private def readObject(in: ObjectInputStream): Unit = tryOrIOException { + val json = in.readUTF() + value = new Schema.Parser().parse(json) + } + + private def tryOrIOException[T](block: => T): T = { + try { + block + } catch { + case e: IOException => + log.error("Exception encountered", e) + throw e + case NonFatal(e) => + log.error("Exception encountered", e) + throw new IOException(e) + } + } + + def write(kryo: Kryo, out: Output): Unit = { + val dos = new DataOutputStream(out) + dos.writeUTF(value.toString()) + dos.flush() + } + + def read(kryo: Kryo, in: Input): Unit = { + val dis = new DataInputStream(in) + val json = dis.readUTF() + value = new Schema.Parser().parse(json) + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala new file mode 100755 index 0000000000000..b3c8a669cf820 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +package object avro { + /** + * Adds a method, `avro`, to DataFrameWriter that allows you to write avro files using + * the DataFileWriter + */ + implicit class AvroDataFrameWriter[T](writer: DataFrameWriter[T]) { + def avro: String => Unit = writer.format("avro").save + } + + /** + * Adds a method, `avro`, to DataFrameReader that allows you to read avro files using + * the DataFileReader + */ + implicit class AvroDataFrameReader(reader: DataFrameReader) { + def avro: String => DataFrame = reader.format("avro").load + + @scala.annotation.varargs + def avro(sources: String*): DataFrame = reader.format("avro").load(sources: _*) + } +} diff --git a/external/avro/src/test/resources/episodes.avro b/external/avro/src/test/resources/episodes.avro new file mode 100644 index 0000000000000..58a028ce19e6a Binary files /dev/null and b/external/avro/src/test/resources/episodes.avro differ diff --git a/external/avro/src/test/resources/log4j.properties b/external/avro/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..c18a724007c27 --- /dev/null +++ b/external/avro/src/test/resources/log4j.properties @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file core/target/unit-tests.log +log4j.rootLogger=DEBUG, CA, FA + +#Console Appender +log4j.appender.CA=org.apache.log4j.ConsoleAppender +log4j.appender.CA.layout=org.apache.log4j.PatternLayout +log4j.appender.CA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c: %m%n +log4j.appender.CA.Threshold = WARN + + +#File Appender +log4j.appender.FA=org.apache.log4j.FileAppender +log4j.appender.FA.append=false +log4j.appender.FA.file=target/unit-tests.log +log4j.appender.FA.layout=org.apache.log4j.PatternLayout +log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Set the logger level of File Appender to WARN +log4j.appender.FA.Threshold = INFO + +# Some packages are noisy for no good reason. +log4j.additivity.parquet.hadoop.ParquetRecordReader=false +log4j.logger.parquet.hadoop.ParquetRecordReader=OFF + +log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false +log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF + +log4j.additivity.org.apache.hadoop.hive.metastore.RetryingHMSHandler=false +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF + +log4j.additivity.hive.ql.metadata.Hive=false +log4j.logger.hive.ql.metadata.Hive=OFF \ No newline at end of file diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro new file mode 100755 index 0000000000000..fece892444979 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro new file mode 100755 index 0000000000000..1ca623a07dcf3 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro new file mode 100755 index 0000000000000..a12e9459e7461 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro new file mode 100755 index 0000000000000..60c095691d5d5 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro new file mode 100755 index 0000000000000..af56dfc8083dc Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro new file mode 100755 index 0000000000000..87d78447526f9 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro new file mode 100755 index 0000000000000..c326fc434bf18 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro new file mode 100755 index 0000000000000..279f36c317eb8 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro new file mode 100755 index 0000000000000..8d70f5d1274d4 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro new file mode 100755 index 0000000000000..6839d7217e492 Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro differ diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro new file mode 100755 index 0000000000000..aedc7f7e0e61c Binary files /dev/null and b/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro differ diff --git a/external/avro/src/test/resources/test.avro b/external/avro/src/test/resources/test.avro new file mode 100644 index 0000000000000..6425e2107e304 Binary files /dev/null and b/external/avro/src/test/resources/test.avro differ diff --git a/external/avro/src/test/resources/test.avsc b/external/avro/src/test/resources/test.avsc new file mode 100644 index 0000000000000..d7119a01f6aa0 --- /dev/null +++ b/external/avro/src/test/resources/test.avsc @@ -0,0 +1,53 @@ +{ + "type" : "record", + "name" : "test_schema", + "fields" : [{ + "name" : "string", + "type" : "string", + "doc" : "Meaningless string of characters" + }, { + "name" : "simple_map", + "type" : {"type": "map", "values": "int"} + }, { + "name" : "complex_map", + "type" : {"type": "map", "values": {"type": "map", "values": "string"}} + }, { + "name" : "union_string_null", + "type" : ["null", "string"] + }, { + "name" : "union_int_long_null", + "type" : ["int", "long", "null"] + }, { + "name" : "union_float_double", + "type" : ["float", "double"] + }, { + "name": "fixed3", + "type": {"type": "fixed", "size": 3, "name": "fixed3"} + }, { + "name": "fixed2", + "type": {"type": "fixed", "size": 2, "name": "fixed2"} + }, { + "name": "enum", + "type": { "type": "enum", + "name": "Suit", + "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"] + } + }, { + "name": "record", + "type": { + "type": "record", + "name": "record", + "aliases": ["RecordAlias"], + "fields" : [{ + "name": "value_field", + "type": "string" + }] + } + }, { + "name": "array_of_boolean", + "type": {"type": "array", "items": "boolean"} + }, { + "name": "bytes", + "type": "bytes" + }] +} diff --git a/external/avro/src/test/resources/test.json b/external/avro/src/test/resources/test.json new file mode 100644 index 0000000000000..780189a92b378 --- /dev/null +++ b/external/avro/src/test/resources/test.json @@ -0,0 +1,42 @@ +{ + "string": "OMG SPARK IS AWESOME", + "simple_map": {"abc": 1, "bcd": 7}, + "complex_map": {"key": {"a": "b", "c": "d"}}, + "union_string_null": {"string": "abc"}, + "union_int_long_null": {"int": 1}, + "union_float_double": {"float": 3.1415926535}, + "fixed3":"\u0002\u0003\u0004", + "fixed2":"\u0011\u0012", + "enum": "SPADES", + "record": {"value_field": "Two things are infinite: the universe and human stupidity; and I'm not sure about universe."}, + "array_of_boolean": [true, false, false], + "bytes": "\u0041\u0042\u0043" +} +{ + "string": "Terran is IMBA!", + "simple_map": {"mmm": 0, "qqq": 66}, + "complex_map": {"key": {"1": "2", "3": "4"}}, + "union_string_null": {"string": "123"}, + "union_int_long_null": {"long": 66}, + "union_float_double": {"double": 6.6666666666666}, + "fixed3":"\u0007\u0007\u0007", + "fixed2":"\u0001\u0002", + "enum": "CLUBS", + "record": {"value_field": "Life did not intend to make us perfect. Whoever is perfect belongs in a museum."}, + "array_of_boolean": [], + "bytes": "" +} +{ + "string": "The cake is a LIE!", + "simple_map": {}, + "complex_map": {"key": {}}, + "union_string_null": {"null": null}, + "union_int_long_null": {"null": null}, + "union_float_double": {"double": 0}, + "fixed3":"\u0011\u0022\u0009", + "fixed2":"\u0010\u0090", + "enum": "DIAMONDS", + "record": {"value_field": "TEST_STR123"}, + "array_of_boolean": [false], + "bytes": "\u0053" +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala new file mode 100644 index 0000000000000..6ed66563b987a --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -0,0 +1,803 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io._ +import java.nio.file.Files +import java.sql.{Date, Timestamp} +import java.util.{TimeZone, UUID} + +import scala.collection.JavaConverters._ + +import org.apache.avro.Schema +import org.apache.avro.Schema.{Field, Type} +import org.apache.avro.file.DataFileWriter +import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} +import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} +import org.apache.commons.io.FileUtils + +import org.apache.spark.sql._ +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.types._ + +class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { + val episodesFile = "src/test/resources/episodes.avro" + val testFile = "src/test/resources/test.avro" + + override protected def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) + } + + def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = { + val originalEntries = spark.read.avro(testFile).collect() + val newEntries = spark.read.avro(newFile) + checkAnswer(newEntries, originalEntries) + } + + test("reading from multiple paths") { + val df = spark.read.avro(episodesFile, episodesFile) + assert(df.count == 16) + } + + test("reading and writing partitioned data") { + val df = spark.read.avro(episodesFile) + val fields = List("title", "air_date", "doctor") + for (field <- fields) { + withTempPath { dir => + val outputDir = s"$dir/${UUID.randomUUID}" + df.write.partitionBy(field).avro(outputDir) + val input = spark.read.avro(outputDir) + // makes sure that no fields got dropped. + // We convert Rows to Seqs in order to work around SPARK-10325 + assert(input.select(field).collect().map(_.toSeq).toSet === + df.select(field).collect().map(_.toSeq).toSet) + } + } + } + + test("request no fields") { + val df = spark.read.avro(episodesFile) + df.createOrReplaceTempView("avro_table") + assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) + } + + test("convert formats") { + withTempPath { dir => + val df = spark.read.avro(episodesFile) + df.write.parquet(dir.getCanonicalPath) + assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) + } + } + + test("rearrange internal schema") { + withTempPath { dir => + val df = spark.read.avro(episodesFile) + df.select("doctor", "title").write.avro(dir.getCanonicalPath) + } + } + + test("test NULL avro type") { + withTempPath { dir => + val fields = + Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + avroRec.put("null", null) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + intercept[IncompatibleSchemaException] { + spark.read.avro(s"$dir.avro") + } + } + } + + test("union(int, long) is read as long") { + withTempPath { dir => + val avroSchema: Schema = { + val union = + Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toLong) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", 2) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.avro(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", LongType, nullable = true))) + assert(df.collect().toSet == Set(Row(1L), Row(2L))) + } + } + + test("union(float, double) is read as double") { + withTempPath { dir => + val avroSchema: Schema = { + val union = + Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toFloat) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", 2.toDouble) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.avro(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true))) + assert(df.collect().toSet == Set(Row(1.toDouble), Row(2.toDouble))) + } + } + + test("union(float, double, null) is read as nullable double") { + withTempPath { dir => + val avroSchema: Schema = { + val union = Schema.createUnion( + List(Schema.create(Type.FLOAT), + Schema.create(Type.DOUBLE), + Schema.create(Type.NULL) + ).asJava + ) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toFloat) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", null) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.avro(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true))) + assert(df.collect().toSet == Set(Row(1.toDouble), Row(null))) + } + } + + test("Union of a single type") { + withTempPath { dir => + val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) + val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + + avroRec.put("field1", 8) + + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val df = spark.read.avro(s"$dir.avro") + assert(df.first() == Row(8)) + } + } + + test("Complex Union Type") { + withTempPath { dir => + val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4) + val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava) + val complexUnionType = Schema.createUnion( + List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava) + val fields = Seq( + new Field("field1", complexUnionType, "doc", null), + new Field("field2", complexUnionType, "doc", null), + new Field("field3", complexUnionType, "doc", null), + new Field("field4", complexUnionType, "doc", null) + ).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + val field1 = 1234 + val field2 = "Hope that was not load bearing" + val field3 = Array[Byte](1, 2, 3, 4) + val field4 = "e2" + avroRec.put("field1", field1) + avroRec.put("field2", field2) + avroRec.put("field3", new Fixed(fixedSchema, field3)) + avroRec.put("field4", new EnumSymbol(enumSchema, field4)) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val df = spark.sqlContext.read.avro(s"$dir.avro") + assertResult(field1)(df.selectExpr("field1.member0").first().get(0)) + assertResult(field2)(df.selectExpr("field2.member1").first().get(0)) + assertResult(field3)(df.selectExpr("field3.member2").first().get(0)) + assertResult(field4)(df.selectExpr("field4.member3").first().get(0)) + } + } + + test("Lots of nulls") { + withTempPath { dir => + val schema = StructType(Seq( + StructField("binary", BinaryType, true), + StructField("timestamp", TimestampType, true), + StructField("array", ArrayType(ShortType), true), + StructField("map", MapType(StringType, StringType), true), + StructField("struct", StructType(Seq(StructField("int", IntegerType, true)))))) + val rdd = spark.sparkContext.parallelize(Seq[Row]( + Row(null, new Timestamp(1), Array[Short](1, 2, 3), null, null), + Row(null, null, null, null, null), + Row(null, null, null, null, null), + Row(null, null, null, null, null))) + val df = spark.createDataFrame(rdd, schema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + } + } + + test("Struct field type") { + withTempPath { dir => + val schema = StructType(Seq( + StructField("float", FloatType, true), + StructField("short", ShortType, true), + StructField("byte", ByteType, true), + StructField("boolean", BooleanType, true) + )) + val rdd = spark.sparkContext.parallelize(Seq( + Row(1f, 1.toShort, 1.toByte, true), + Row(2f, 2.toShort, 2.toByte, true), + Row(3f, 3.toShort, 3.toByte, true) + )) + val df = spark.createDataFrame(rdd, schema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + } + } + + test("Date field type") { + withTempPath { dir => + val schema = StructType(Seq( + StructField("float", FloatType, true), + StructField("date", DateType, true) + )) + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + val rdd = spark.sparkContext.parallelize(Seq( + Row(1f, null), + Row(2f, new Date(1451948400000L)), + Row(3f, new Date(1460066400500L)) + )) + val df = spark.createDataFrame(rdd, schema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + assert(spark.read.avro(dir.toString).select("date").collect().map(_(0)).toSet == + Array(null, 1451865600000L, 1459987200000L).toSet) + } + } + + test("Array data types") { + withTempPath { dir => + val testSchema = StructType(Seq( + StructField("byte_array", ArrayType(ByteType), true), + StructField("short_array", ArrayType(ShortType), true), + StructField("float_array", ArrayType(FloatType), true), + StructField("bool_array", ArrayType(BooleanType), true), + StructField("long_array", ArrayType(LongType), true), + StructField("double_array", ArrayType(DoubleType), true), + StructField("decimal_array", ArrayType(DecimalType(10, 0)), true), + StructField("bin_array", ArrayType(BinaryType), true), + StructField("timestamp_array", ArrayType(TimestampType), true), + StructField("array_array", ArrayType(ArrayType(StringType), true), true), + StructField("struct_array", ArrayType( + StructType(Seq(StructField("name", StringType, true))))))) + + val arrayOfByte = new Array[Byte](4) + for (i <- arrayOfByte.indices) { + arrayOfByte(i) = i.toByte + } + + val rdd = spark.sparkContext.parallelize(Seq( + Row(arrayOfByte, Array[Short](1, 2, 3, 4), Array[Float](1f, 2f, 3f, 4f), + Array[Boolean](true, false, true, false), Array[Long](1L, 2L), Array[Double](1.0, 2.0), + Array[BigDecimal](BigDecimal.valueOf(3)), Array[Array[Byte]](arrayOfByte, arrayOfByte), + Array[Timestamp](new Timestamp(0)), + Array[Array[String]](Array[String]("CSH, tearing down the walls that divide us", "-jd")), + Array[Row](Row("Bobby G. can't swim"))))) + val df = spark.createDataFrame(rdd, testSchema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + } + } + + test("write with compression") { + withTempPath { dir => + val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" + val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" + val uncompressDir = s"$dir/uncompress" + val deflateDir = s"$dir/deflate" + val snappyDir = s"$dir/snappy" + + val df = spark.read.avro(testFile) + spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed") + df.write.avro(uncompressDir) + spark.conf.set(AVRO_COMPRESSION_CODEC, "deflate") + spark.conf.set(AVRO_DEFLATE_LEVEL, "9") + df.write.avro(deflateDir) + spark.conf.set(AVRO_COMPRESSION_CODEC, "snappy") + df.write.avro(snappyDir) + + val uncompressSize = FileUtils.sizeOfDirectory(new File(uncompressDir)) + val deflateSize = FileUtils.sizeOfDirectory(new File(deflateDir)) + val snappySize = FileUtils.sizeOfDirectory(new File(snappyDir)) + + assert(uncompressSize > deflateSize) + assert(snappySize > deflateSize) + } + } + + test("dsl test") { + val results = spark.read.avro(episodesFile).select("title").collect() + assert(results.length === 8) + } + + test("support of various data types") { + // This test uses data from test.avro. You can see the data and the schema of this file in + // test.json and test.avsc + val all = spark.read.avro(testFile).collect() + assert(all.length == 3) + + val str = spark.read.avro(testFile).select("string").collect() + assert(str.map(_(0)).toSet.contains("Terran is IMBA!")) + + val simple_map = spark.read.avro(testFile).select("simple_map").collect() + assert(simple_map(0)(0).getClass.toString.contains("Map")) + assert(simple_map.map(_(0).asInstanceOf[Map[String, Some[Int]]].size).toSet == Set(2, 0)) + + val union0 = spark.read.avro(testFile).select("union_string_null").collect() + assert(union0.map(_(0)).toSet == Set("abc", "123", null)) + + val union1 = spark.read.avro(testFile).select("union_int_long_null").collect() + assert(union1.map(_(0)).toSet == Set(66, 1, null)) + + val union2 = spark.read.avro(testFile).select("union_float_double").collect() + assert( + union2 + .map(x => new java.lang.Double(x(0).toString)) + .exists(p => Math.abs(p - Math.PI) < 0.001)) + + val fixed = spark.read.avro(testFile).select("fixed3").collect() + assert(fixed.map(_(0).asInstanceOf[Array[Byte]]).exists(p => p(1) == 3)) + + val enum = spark.read.avro(testFile).select("enum").collect() + assert(enum.map(_(0)).toSet == Set("SPADES", "CLUBS", "DIAMONDS")) + + val record = spark.read.avro(testFile).select("record").collect() + assert(record(0)(0).getClass.toString.contains("Row")) + assert(record.map(_(0).asInstanceOf[Row](0)).contains("TEST_STR123")) + + val array_of_boolean = spark.read.avro(testFile).select("array_of_boolean").collect() + assert(array_of_boolean.map(_(0).asInstanceOf[Seq[Boolean]].size).toSet == Set(3, 1, 0)) + + val bytes = spark.read.avro(testFile).select("bytes").collect() + assert(bytes.map(_(0).asInstanceOf[Array[Byte]].length).toSet == Set(3, 1, 0)) + } + + test("sql test") { + spark.sql( + s""" + |CREATE TEMPORARY VIEW avroTable + |USING avro + |OPTIONS (path "$episodesFile") + """.stripMargin.replaceAll("\n", " ")) + + assert(spark.sql("SELECT * FROM avroTable").collect().length === 8) + } + + test("conversion to avro and back") { + // Note that test.avro includes a variety of types, some of which are nullable. We expect to + // get the same values back. + withTempPath { dir => + val avroDir = s"$dir/avro" + spark.read.avro(testFile).write.avro(avroDir) + checkReloadMatchesSaved(testFile, avroDir) + } + } + + test("conversion to avro and back with namespace") { + // Note that test.avro includes a variety of types, some of which are nullable. We expect to + // get the same values back. + withTempPath { tempDir => + val name = "AvroTest" + val namespace = "com.databricks.spark.avro" + val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) + + val avroDir = tempDir + "/namedAvro" + spark.read.avro(testFile).write.options(parameters).avro(avroDir) + checkReloadMatchesSaved(testFile, avroDir) + + // Look at raw file and make sure has namespace info + val rawSaved = spark.sparkContext.textFile(avroDir) + val schema = rawSaved.collect().mkString("") + assert(schema.contains(name)) + assert(schema.contains(namespace)) + } + } + + test("converting some specific sparkSQL types to avro") { + withTempPath { tempDir => + val testSchema = StructType(Seq( + StructField("Name", StringType, false), + StructField("Length", IntegerType, true), + StructField("Time", TimestampType, false), + StructField("Decimal", DecimalType(10, 2), true), + StructField("Binary", BinaryType, false))) + + val arrayOfByte = new Array[Byte](4) + for (i <- arrayOfByte.indices) { + arrayOfByte(i) = i.toByte + } + val cityRDD = spark.sparkContext.parallelize(Seq( + Row("San Francisco", 12, new Timestamp(666), null, arrayOfByte), + Row("Palo Alto", null, new Timestamp(777), null, arrayOfByte), + Row("Munich", 8, new Timestamp(42), Decimal(3.14), arrayOfByte))) + val cityDataFrame = spark.createDataFrame(cityRDD, testSchema) + + val avroDir = tempDir + "/avro" + cityDataFrame.write.avro(avroDir) + assert(spark.read.avro(avroDir).collect().length == 3) + + // TimesStamps are converted to longs + val times = spark.read.avro(avroDir).select("Time").collect() + assert(times.map(_(0)).toSet == Set(666, 777, 42)) + + // DecimalType should be converted to string + val decimals = spark.read.avro(avroDir).select("Decimal").collect() + assert(decimals.map(_(0)).contains("3.14")) + + // There should be a null entry + val length = spark.read.avro(avroDir).select("Length").collect() + assert(length.map(_(0)).contains(null)) + + val binary = spark.read.avro(avroDir).select("Binary").collect() + for (i <- arrayOfByte.indices) { + assert(binary(1)(0).asInstanceOf[Array[Byte]](i) == arrayOfByte(i)) + } + } + } + + test("correctly read long as date/timestamp type") { + withTempPath { tempDir => + val sparkSession = spark + import sparkSession.implicits._ + + val currentTime = new Timestamp(System.currentTimeMillis()) + val currentDate = new Date(System.currentTimeMillis()) + val schema = StructType(Seq( + StructField("_1", DateType, false), StructField("_2", TimestampType, false))) + val writeDs = Seq((currentDate, currentTime)).toDS + + val avroDir = tempDir + "/avro" + writeDs.write.avro(avroDir) + assert(spark.read.avro(avroDir).collect().length == 1) + + val readDs = spark.read.schema(schema).avro(avroDir).as[(Date, Timestamp)] + + assert(readDs.collect().sameElements(writeDs.collect())) + } + } + + test("support of globbed paths") { + val e1 = spark.read.avro("*/test/resources/episodes.avro").collect() + assert(e1.length == 8) + + val e2 = spark.read.avro("src/*/*/episodes.avro").collect() + assert(e2.length == 8) + } + + test("does not coerce null date/timestamp value to 0 epoch.") { + withTempPath { tempDir => + val sparkSession = spark + import sparkSession.implicits._ + + val nullTime: Timestamp = null + val nullDate: Date = null + val schema = StructType(Seq( + StructField("_1", DateType, nullable = true), + StructField("_2", TimestampType, nullable = true)) + ) + val writeDs = Seq((nullDate, nullTime)).toDS + + val avroDir = tempDir + "/avro" + writeDs.write.avro(avroDir) + val readValues = spark.read.schema(schema).avro(avroDir).as[(Date, Timestamp)].collect + + assert(readValues.size == 1) + assert(readValues.head == ((nullDate, nullTime))) + } + } + + test("support user provided avro schema") { + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name" : "string", + | "type" : "string", + | "doc" : "Meaningless string of characters" + | }] + |} + """.stripMargin + val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema).avro(testFile).collect() + val expected = spark.read.avro(testFile).select("string").collect() + assert(result.sameElements(expected)) + } + + test("support user provided avro schema with defaults for missing fields") { + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name" : "missingField", + | "type" : "string", + | "default" : "foo" + | }] + |} + """.stripMargin + val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema) + .avro(testFile).select("missingField").first + assert(result === Row("foo")) + } + + test("reading from invalid path throws exception") { + + // Directory given has no avro files + intercept[AnalysisException] { + withTempPath(dir => spark.read.avro(dir.getCanonicalPath)) + } + + intercept[AnalysisException] { + spark.read.avro("very/invalid/path/123.avro") + } + + // In case of globbed path that can't be matched to anything, another exception is thrown (and + // exception message is helpful) + intercept[AnalysisException] { + spark.read.avro("*/*/*/*/*/*/*/something.avro") + } + + intercept[FileNotFoundException] { + withTempPath { dir => + FileUtils.touch(new File(dir, "test")) + spark.read.avro(dir.toString) + } + } + + } + + test("SQL test insert overwrite") { + withTempPath { tempDir => + val tempEmptyDir = s"$tempDir/sqlOverwrite" + // Create a temp directory for table that will be overwritten + new File(tempEmptyDir).mkdirs() + spark.sql( + s""" + |CREATE TEMPORARY VIEW episodes + |USING avro + |OPTIONS (path "$episodesFile") + """.stripMargin.replaceAll("\n", " ")) + spark.sql( + s""" + |CREATE TEMPORARY VIEW episodesEmpty + |(name string, air_date string, doctor int) + |USING avro + |OPTIONS (path "$tempEmptyDir") + """.stripMargin.replaceAll("\n", " ")) + + assert(spark.sql("SELECT * FROM episodes").collect().length === 8) + assert(spark.sql("SELECT * FROM episodesEmpty").collect().isEmpty) + + spark.sql( + s""" + |INSERT OVERWRITE TABLE episodesEmpty + |SELECT * FROM episodes + """.stripMargin.replaceAll("\n", " ")) + assert(spark.sql("SELECT * FROM episodesEmpty").collect().length == 8) + } + } + + test("test save and load") { + // Test if load works as expected + withTempPath { tempDir => + val df = spark.read.avro(episodesFile) + assert(df.count == 8) + + val tempSaveDir = s"$tempDir/save/" + + df.write.avro(tempSaveDir) + val newDf = spark.read.avro(tempSaveDir) + assert(newDf.count == 8) + } + } + + test("test load with non-Avro file") { + // Test if load works as expected + withTempPath { tempDir => + val df = spark.read.avro(episodesFile) + assert(df.count == 8) + + val tempSaveDir = s"$tempDir/save/" + df.write.avro(tempSaveDir) + + Files.createFile(new File(tempSaveDir, "non-avro").toPath) + + val newDf = spark + .read + .option(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") + .avro(tempSaveDir) + + assert(newDf.count == 8) + } + } + + test("read avro with user defined schema: read partial columns") { + val partialColumns = StructType(Seq( + StructField("string", StringType, false), + StructField("simple_map", MapType(StringType, IntegerType), false), + StructField("complex_map", MapType(StringType, MapType(StringType, StringType)), false), + StructField("union_string_null", StringType, true), + StructField("union_int_long_null", LongType, true), + StructField("fixed3", BinaryType, true), + StructField("fixed2", BinaryType, true), + StructField("enum", StringType, false), + StructField("record", StructType(Seq(StructField("value_field", StringType, false))), false), + StructField("array_of_boolean", ArrayType(BooleanType), false), + StructField("bytes", BinaryType, true))) + val withSchema = spark.read.schema(partialColumns).avro(testFile).collect() + val withOutSchema = spark + .read + .avro(testFile) + .select("string", "simple_map", "complex_map", "union_string_null", "union_int_long_null", + "fixed3", "fixed2", "enum", "record", "array_of_boolean", "bytes") + .collect() + assert(withSchema.sameElements(withOutSchema)) + } + + test("read avro with user defined schema: read non-exist columns") { + val schema = + StructType( + Seq( + StructField("non_exist_string", StringType, true), + StructField( + "record", + StructType(Seq( + StructField("non_exist_field", StringType, false), + StructField("non_exist_field2", StringType, false))), + false))) + val withEmptyColumn = spark.read.schema(schema).avro(testFile).collect() + + assert(withEmptyColumn.forall(_ == Row(null: String, Row(null: String, null: String)))) + } + + test("read avro file partitioned") { + withTempPath { dir => + val sparkSession = spark + import sparkSession.implicits._ + val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records") + val outputDir = s"$dir/${UUID.randomUUID}" + df.write.avro(outputDir) + val input = spark.read.avro(outputDir) + assert(input.collect.toSet.size === 1024 * 3 + 1) + assert(input.rdd.partitions.size > 2) + } + } + + case class NestedBottom(id: Int, data: String) + + case class NestedMiddle(id: Int, data: NestedBottom) + + case class NestedTop(id: Int, data: NestedMiddle) + + test("saving avro that has nested records with the same name") { + withTempPath { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) + val outputFolder = s"$tempDir/duplicate_names/" + writeDf.write.avro(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.avro(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } + + case class NestedMiddleArray(id: Int, data: Array[NestedBottom]) + + case class NestedTopArray(id: Int, data: NestedMiddleArray) + + test("saving avro that has nested records with the same name inside an array") { + withTempPath { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame( + List(NestedTopArray(1, NestedMiddleArray(2, Array( + NestedBottom(3, "1"), NestedBottom(4, "2") + )))) + ) + val outputFolder = s"$tempDir/duplicate_names_array/" + writeDf.write.avro(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.avro(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } + + case class NestedMiddleMap(id: Int, data: Map[String, NestedBottom]) + + case class NestedTopMap(id: Int, data: NestedMiddleMap) + + test("saving avro that has nested records with the same name inside a map") { + withTempPath { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame( + List(NestedTopMap(1, NestedMiddleMap(2, Map( + "1" -> NestedBottom(3, "1"), "2" -> NestedBottom(4, "2") + )))) + ) + val outputFolder = s"$tempDir/duplicate_names_map/" + writeDf.write.avro(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.avro(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala new file mode 100755 index 0000000000000..a0f88515ed9d4 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance} + +class SerializableConfigurationSuite extends SparkFunSuite { + + private def testSerialization(serializer: SerializerInstance): Unit = { + import AvroFileFormat.SerializableConfiguration + val conf = new SerializableConfiguration(new Configuration()) + + val serialized = serializer.serialize(conf) + + serializer.deserialize[Any](serialized) match { + case c: SerializableConfiguration => + assert(c.log != null, "log was null") + assert(c.value != null, "value was null") + case other => fail( + s"Expecting ${classOf[SerializableConfiguration]}, but got ${other.getClass}.") + } + } + + test("serialization with JavaSerializer") { + testSerialization(new JavaSerializer(new SparkConf()).newInstance()) + } + + test("serialization with KryoSerializer") { + testSerialization(new KryoSerializer(new SparkConf()).newInstance()) + } + +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala new file mode 100644 index 0000000000000..510bcbdd31929 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance} + +class SerializableSchemaSuite extends SparkFunSuite { + + private def testSerialization(serializer: SerializerInstance): Unit = { + val avroTypeJson = + s""" + |{ + | "type": "string", + | "name": "my_string" + |} + """.stripMargin + val avroSchema = new Schema.Parser().parse(avroTypeJson) + val serializableSchema = new SerializableSchema(avroSchema) + val serialized = serializer.serialize(serializableSchema) + + serializer.deserialize[Any](serialized) match { + case c: SerializableSchema => + assert(c.log != null, "log was null") + assert(c.value != null, "value was null") + assert(c.value == avroSchema) + case other => fail( + s"Expecting ${classOf[SerializableSchema]}, but got ${other.getClass}.") + } + } + + test("serialization with JavaSerializer") { + testSerialization(new JavaSerializer(new SparkConf()).newInstance()) + } + + test("serialization with KryoSerializer") { + testSerialization(new KryoSerializer(new SparkConf()).newInstance()) + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index f26c134c2f6e9..badaa69cc303c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType /** @@ -86,7 +86,7 @@ class KafkaContinuousReader( KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { import scala.collection.JavaConverters._ val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) @@ -106,9 +106,9 @@ class KafkaContinuousReader( startOffsets.toSeq.map { case (topicPartition, start) => - KafkaContinuousDataReaderFactory( + KafkaContinuousInputPartition( topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) - .asInstanceOf[DataReaderFactory[UnsafeRow]] + .asInstanceOf[InputPartition[UnsafeRow]] }.asJava } @@ -146,7 +146,7 @@ class KafkaContinuousReader( } /** - * A data reader factory for continuous Kafka processing. This will be serialized and transformed + * An input partition for continuous Kafka processing. This will be serialized and transformed * into a full reader on executors. * * @param topicPartition The (topic, partition) pair this task is responsible for. @@ -156,23 +156,23 @@ class KafkaContinuousReader( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. */ -case class KafkaContinuousDataReaderFactory( +case class KafkaContinuousInputPartition( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousDataReaderFactory[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousInputPartition[UnsafeRow] { - override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[UnsafeRow] = { + override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[UnsafeRow] = { val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] require(kafkaOffset.topicPartition == topicPartition, s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") - new KafkaContinuousDataReader( + new KafkaContinuousInputPartitionReader( topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) } - override def createDataReader(): KafkaContinuousDataReader = { - new KafkaContinuousDataReader( + override def createPartitionReader(): KafkaContinuousInputPartitionReader = { + new KafkaContinuousInputPartitionReader( topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) } } @@ -187,12 +187,12 @@ case class KafkaContinuousDataReaderFactory( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. */ -class KafkaContinuousDataReader( +class KafkaContinuousInputPartitionReader( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[UnsafeRow] { private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) private val converter = new KafkaRecordToUnsafeRowConverter diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index 48508d057a540..941f0ab177e48 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -395,7 +395,7 @@ private[kafka010] object KafkaDataConsumer extends Logging { // likely running on a beefy machine that can handle a large number of simultaneously // active consumers. - if (entry.getValue.inUse == false && this.size > capacity) { + if (!entry.getValue.inUse && this.size > capacity) { logWarning( s"KafkaConsumer cache hitting max capacity of $capacity, " + s"removing consumer for ${entry.getKey}") diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 2ed49ba3f5495..737da2e51b125 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.UninterruptibleThread @@ -68,7 +68,7 @@ private[kafka010] class KafkaMicroBatchReader( private val pollTimeoutMs = options.getLong( "kafkaConsumer.pollTimeoutMs", - SparkEnv.get.conf.getTimeAsMs("spark.network.timeout", "120s")) + SparkEnv.get.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L) private val maxOffsetsPerTrigger = Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) @@ -101,7 +101,7 @@ private[kafka010] class KafkaMicroBatchReader( } } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { // Find the new partitions, and get their earliest offsets val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) @@ -143,10 +143,10 @@ private[kafka010] class KafkaMicroBatchReader( // Generate factories based on the offset ranges val factories = offsetRanges.map { range => - new KafkaMicroBatchDataReaderFactory( + new KafkaMicroBatchInputPartition( range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } - factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava + factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava } override def getStartOffset: Offset = { @@ -169,7 +169,7 @@ private[kafka010] class KafkaMicroBatchReader( kafkaOffsetReader.close() } - override def toString(): String = s"Kafka[$kafkaOffsetReader]" + override def toString(): String = s"KafkaV2[$kafkaOffsetReader]" /** * Read initial partition offsets from the checkpoint, or decide the offsets and write them to @@ -299,27 +299,28 @@ private[kafka010] class KafkaMicroBatchReader( } } -/** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] case class KafkaMicroBatchDataReaderFactory( +/** A [[InputPartition]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] case class KafkaMicroBatchInputPartition( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] { + reuseKafkaConsumer: Boolean) extends InputPartition[UnsafeRow] { override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray - override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( - offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) + override def createPartitionReader(): InputPartitionReader[UnsafeRow] = + new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs, + failOnDataLoss, reuseKafkaConsumer) } -/** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] case class KafkaMicroBatchDataReader( +/** A [[InputPartitionReader]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] case class KafkaMicroBatchInputPartitionReader( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging { + reuseKafkaConsumer: Boolean) extends InputPartitionReader[UnsafeRow] with Logging { private val consumer = KafkaDataConsumer.acquire( offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 551641cfdbca8..82066697cb95a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -75,7 +75,17 @@ private[kafka010] class KafkaOffsetReader( * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the * offsets and never commits them. */ - protected var consumer = createConsumer() + @volatile protected var _consumer: Consumer[Array[Byte], Array[Byte]] = null + + protected def consumer: Consumer[Array[Byte], Array[Byte]] = synchronized { + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + if (_consumer == null) { + val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams) + newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId()) + _consumer = consumerStrategy.createConsumer(newKafkaParams) + } + _consumer + } private val maxOffsetFetchAttempts = readerOptions.getOrElse("fetchOffset.numRetries", "3").toInt @@ -95,9 +105,7 @@ private[kafka010] class KafkaOffsetReader( * Closes the connection to Kafka, and cleans up state. */ def close(): Unit = { - runUninterruptibly { - consumer.close() - } + if (_consumer != null) runUninterruptibly { stopConsumer() } kafkaReaderThread.shutdown() } @@ -304,19 +312,14 @@ private[kafka010] class KafkaOffsetReader( } } - /** - * Create a consumer using the new generated group id. We always use a new consumer to avoid - * just using a broken consumer to retry on Kafka errors, which likely will fail again. - */ - private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized { - val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams) - newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId()) - consumerStrategy.createConsumer(newKafkaParams) + private def stopConsumer(): Unit = synchronized { + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + if (_consumer != null) _consumer.close() } private def resetConsumer(): Unit = synchronized { - consumer.close() - consumer = createConsumer() + stopConsumer() + _consumer = null // will automatically get reinitialized again } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index 7103709969c18..c31e6ed3e0903 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -48,7 +48,9 @@ private[kafka010] class KafkaRelation( private val pollTimeoutMs = sourceOptions.getOrElse( "kafkaConsumer.pollTimeoutMs", - sqlContext.sparkContext.conf.getTimeAsMs("spark.network.timeout", "120s").toString + (sqlContext.sparkContext.conf.getTimeAsSeconds( + "spark.network.timeout", + "120s") * 1000L).toString ).toLong override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 1c7b3a29a861f..101e649727fcf 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -84,7 +84,7 @@ private[kafka010] class KafkaSource( private val pollTimeoutMs = sourceOptions.getOrElse( "kafkaConsumer.pollTimeoutMs", - sc.conf.getTimeAsMs("spark.network.timeout", "120s").toString + (sc.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L).toString ).toLong private val maxOffsetsPerTrigger = diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 36b9f0466566b..d225c1ea6b7f1 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSessio import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -149,7 +150,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } /** - * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader]] to read + * Creates a [[ContinuousInputPartitionReader]] to read * Kafka data in a continuous streaming query. */ override def createContinuousReader( diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index ae5b5c52d514e..32923dc9f5a6b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -67,7 +67,7 @@ case class KafkaStreamWriterFactory( override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index e017fd9b84d21..c6412eac97dba 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -563,7 +563,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { ) } - test("ensure stream-stream self-join generates only one offset in offset log") { + test("ensure stream-stream self-join generates only one offset in log and correct metrics") { val topic = newTopic() testUtils.createTopic(topic, partitions = 2) require(testUtils.getLatestOffsets(Set(topic)).size === 2) @@ -587,7 +587,12 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { AddKafkaData(Set(topic), 1, 2), CheckAnswer((1, 1, 1), (2, 2, 2)), AddKafkaData(Set(topic), 6, 3), - CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)) + CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)), + AssertOnQuery { q => + assert(q.availableOffsets.iterator.size == 1) + assert(q.recentProgress.map(_.numInputRows).sum == 4) + true + } ) } } @@ -673,8 +678,8 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) ) - val factories = reader.createUnsafeRowReaderFactories().asScala - .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory]) + val factories = reader.planUnsafeInputPartitions().asScala + .map(_.asInstanceOf[KafkaMicroBatchInputPartition]) withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { assert(factories.size == numPartitionsGenerated) factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala deleted file mode 100644 index aeb8c1dc342b3..0000000000000 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.kafka010 - -import java.{ util => ju } - -import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer } -import org.apache.kafka.common.{ KafkaException, TopicPartition } - -import org.apache.spark.internal.Logging - -/** - * Consumer of single topicpartition, intended for cached reuse. - * Underlying consumer is not threadsafe, so neither is this, - * but processing the same topicpartition and group id in multiple threads is usually bad anyway. - */ -private[kafka010] -class CachedKafkaConsumer[K, V] private( - val groupId: String, - val topic: String, - val partition: Int, - val kafkaParams: ju.Map[String, Object]) extends Logging { - - require(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), - "groupId used for cache key must match the groupId in kafkaParams") - - val topicPartition = new TopicPartition(topic, partition) - - protected val consumer = { - val c = new KafkaConsumer[K, V](kafkaParams) - val tps = new ju.ArrayList[TopicPartition]() - tps.add(topicPartition) - c.assign(tps) - c - } - - // TODO if the buffer was kept around as a random-access structure, - // could possibly optimize re-calculating of an RDD in the same batch - protected var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() - protected var nextOffset = -2L - - def close(): Unit = consumer.close() - - /** - * Get the record for the given offset, waiting up to timeout ms if IO is necessary. - * Sequential forward access will use buffers, but random access will be horribly inefficient. - */ - def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = { - logDebug(s"Get $groupId $topic $partition nextOffset $nextOffset requested $offset") - if (offset != nextOffset) { - logInfo(s"Initial fetch for $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - } - - if (!buffer.hasNext()) { poll(timeout) } - require(buffer.hasNext(), - s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") - var record = buffer.next() - - if (record.offset != offset) { - logInfo(s"Buffer miss for $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - require(buffer.hasNext(), - s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") - record = buffer.next() - require(record.offset == offset, - s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset " + - s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + - "spark.streaming.kafka.allowNonConsecutiveOffsets" - ) - } - - nextOffset = offset + 1 - record - } - - /** - * Start a batch on a compacted topic - */ - def compactedStart(offset: Long, timeout: Long): Unit = { - logDebug(s"compacted start $groupId $topic $partition starting $offset") - // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics - if (offset != nextOffset) { - logInfo(s"Initial fetch for compacted $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - } - } - - /** - * Get the next record in the batch from a compacted topic. - * Assumes compactedStart has been called first, and ignores gaps. - */ - def compactedNext(timeout: Long): ConsumerRecord[K, V] = { - if (!buffer.hasNext()) { - poll(timeout) - } - require(buffer.hasNext(), - s"Failed to get records for compacted $groupId $topic $partition after polling for $timeout") - val record = buffer.next() - nextOffset = record.offset + 1 - record - } - - /** - * Rewind to previous record in the batch from a compacted topic. - * @throws NoSuchElementException if no previous element - */ - def compactedPrevious(): ConsumerRecord[K, V] = { - buffer.previous() - } - - private def seek(offset: Long): Unit = { - logDebug(s"Seeking to $topicPartition $offset") - consumer.seek(topicPartition, offset) - } - - private def poll(timeout: Long): Unit = { - val p = consumer.poll(timeout) - val r = p.records(topicPartition) - logDebug(s"Polled ${p.partitions()} ${r.size}") - buffer = r.listIterator - } - -} - -private[kafka010] -object CachedKafkaConsumer extends Logging { - - private case class CacheKey(groupId: String, topic: String, partition: Int) - - // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap - private var cache: ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]] = null - - /** Must be called before get, once per JVM, to configure the cache. Further calls are ignored */ - def init( - initialCapacity: Int, - maxCapacity: Int, - loadFactor: Float): Unit = CachedKafkaConsumer.synchronized { - if (null == cache) { - logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor") - cache = new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]]( - initialCapacity, loadFactor, true) { - override def removeEldestEntry( - entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer[_, _]]): Boolean = { - if (this.size > maxCapacity) { - try { - entry.getValue.consumer.close() - } catch { - case x: KafkaException => - logError("Error closing oldest Kafka consumer", x) - } - true - } else { - false - } - } - } - } - } - - /** - * Get a cached consumer for groupId, assigned to topic and partition. - * If matching consumer doesn't already exist, will be created using kafkaParams. - */ - def get[K, V]( - groupId: String, - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] = - CachedKafkaConsumer.synchronized { - val k = CacheKey(groupId, topic, partition) - val v = cache.get(k) - if (null == v) { - logInfo(s"Cache miss for $k") - logDebug(cache.keySet.toString) - val c = new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams) - cache.put(k, c) - c - } else { - // any given topicpartition should have a consistent key and value type - v.asInstanceOf[CachedKafkaConsumer[K, V]] - } - } - - /** - * Get a fresh new instance, unassociated with the global cache. - * Caller is responsible for closing - */ - def getUncached[K, V]( - groupId: String, - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] = - new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams) - - /** remove consumer for given groupId, topic, and partition, if it exists */ - def remove(groupId: String, topic: String, partition: Int): Unit = { - val k = CacheKey(groupId, topic, partition) - logInfo(s"Removing $k from cache") - val v = CachedKafkaConsumer.synchronized { - cache.remove(k) - } - if (null != v) { - v.close() - logInfo(s"Removed $k from cache") - } - } -} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index c3221481556f5..0246006acf0bd 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -166,6 +166,8 @@ private[spark] class DirectKafkaInputDStream[K, V]( * which would throw off consumer position. Fix position if this happens. */ private def paranoidPoll(c: Consumer[K, V]): Unit = { + // don't actually want to consume any messages, so pause all partitions + c.pause(c.assignment()) val msgs = c.poll(0) if (!msgs.isEmpty) { // position should be minimum offset per topicpartition @@ -204,8 +206,6 @@ private[spark] class DirectKafkaInputDStream[K, V]( // position for new partitions determined by auto.offset.reset if no commit currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap - // don't want to consume messages, so pause - c.pause(newPartitions.asJava) // find latest available offsets c.seekToEnd(currentOffsets.keySet.asJava) parts.map(tp => tp -> c.position(tp)).toMap @@ -262,9 +262,6 @@ private[spark] class DirectKafkaInputDStream[K, V]( tp -> c.position(tp) }.toMap } - - // don't actually want to consume any messages, so pause all partitions - c.pause(currentOffsets.keySet.asJava) } override def stop(): Unit = this.synchronized { diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala new file mode 100644 index 0000000000000..68c5fe9ab066a --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer} +import org.apache.kafka.common.{KafkaException, TopicPartition} + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging + +private[kafka010] sealed trait KafkaDataConsumer[K, V] { + /** + * Get the record for the given offset if available. + * + * @param offset the offset to fetch. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[K, V] = { + internalConsumer.get(offset, pollTimeoutMs) + } + + /** + * Start a batch on a compacted topic + * + * @param offset the offset to fetch. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = { + internalConsumer.compactedStart(offset, pollTimeoutMs) + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + * + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = { + internalConsumer.compactedNext(pollTimeoutMs) + } + + /** + * Rewind to previous record in the batch from a compacted topic. + * + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + internalConsumer.compactedPrevious() + } + + /** + * Release this consumer from being further used. Depending on its implementation, + * this consumer will be either finalized, or reset for reuse later. + */ + def release(): Unit + + /** Reference to the internal implementation that this wrapper delegates to */ + def internalConsumer: InternalKafkaConsumer[K, V] +} + + +/** + * A wrapper around Kafka's KafkaConsumer. + * This is not for direct use outside this file. + */ +private[kafka010] class InternalKafkaConsumer[K, V]( + val topicPartition: TopicPartition, + val kafkaParams: ju.Map[String, Object]) extends Logging { + + private[kafka010] val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG) + .asInstanceOf[String] + + private val consumer = createConsumer + + /** indicates whether this consumer is in use or not */ + var inUse = true + + /** indicate whether this consumer is going to be stopped in the next release */ + var markedForClose = false + + // TODO if the buffer was kept around as a random-access structure, + // could possibly optimize re-calculating of an RDD in the same batch + @volatile private var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() + @volatile private var nextOffset = InternalKafkaConsumer.UNKNOWN_OFFSET + + override def toString: String = { + "InternalKafkaConsumer(" + + s"hash=${Integer.toHexString(hashCode)}, " + + s"groupId=$groupId, " + + s"topicPartition=$topicPartition)" + } + + /** Create a KafkaConsumer to fetch records for `topicPartition` */ + private def createConsumer: KafkaConsumer[K, V] = { + val c = new KafkaConsumer[K, V](kafkaParams) + val topics = ju.Arrays.asList(topicPartition) + c.assign(topics) + c + } + + def close(): Unit = consumer.close() + + /** + * Get the record for the given offset, waiting up to timeout ms if IO is necessary. + * Sequential forward access will use buffers, but random access will be horribly inefficient. + */ + def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = { + logDebug(s"Get $groupId $topicPartition nextOffset $nextOffset requested $offset") + if (offset != nextOffset) { + logInfo(s"Initial fetch for $groupId $topicPartition $offset") + seek(offset) + poll(timeout) + } + + if (!buffer.hasNext()) { + poll(timeout) + } + require(buffer.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout") + var record = buffer.next() + + if (record.offset != offset) { + logInfo(s"Buffer miss for $groupId $topicPartition $offset") + seek(offset) + poll(timeout) + require(buffer.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout") + record = buffer.next() + require(record.offset == offset, + s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset " + + s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + + "spark.streaming.kafka.allowNonConsecutiveOffsets" + ) + } + + nextOffset = offset + 1 + record + } + + /** + * Start a batch on a compacted topic + */ + def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = { + logDebug(s"compacted start $groupId $topicPartition starting $offset") + // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics + if (offset != nextOffset) { + logInfo(s"Initial fetch for compacted $groupId $topicPartition $offset") + seek(offset) + poll(pollTimeoutMs) + } + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + */ + def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = { + if (!buffer.hasNext()) { + poll(pollTimeoutMs) + } + require(buffer.hasNext(), + s"Failed to get records for compacted $groupId $topicPartition " + + s"after polling for $pollTimeoutMs") + val record = buffer.next() + nextOffset = record.offset + 1 + record + } + + /** + * Rewind to previous record in the batch from a compacted topic. + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + buffer.previous() + } + + private def seek(offset: Long): Unit = { + logDebug(s"Seeking to $topicPartition $offset") + consumer.seek(topicPartition, offset) + } + + private def poll(timeout: Long): Unit = { + val p = consumer.poll(timeout) + val r = p.records(topicPartition) + logDebug(s"Polled ${p.partitions()} ${r.size}") + buffer = r.listIterator + } + +} + +private[kafka010] case class CacheKey(groupId: String, topicPartition: TopicPartition) + +private[kafka010] object KafkaDataConsumer extends Logging { + + private case class CachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V]) + extends KafkaDataConsumer[K, V] { + assert(internalConsumer.inUse) + override def release(): Unit = KafkaDataConsumer.release(internalConsumer) + } + + private case class NonCachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V]) + extends KafkaDataConsumer[K, V] { + override def release(): Unit = internalConsumer.close() + } + + // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap + private[kafka010] var cache: ju.Map[CacheKey, InternalKafkaConsumer[_, _]] = null + + /** + * Must be called before acquire, once per JVM, to configure the cache. + * Further calls are ignored. + */ + def init( + initialCapacity: Int, + maxCapacity: Int, + loadFactor: Float): Unit = synchronized { + if (null == cache) { + logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor") + cache = new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer[_, _]]( + initialCapacity, loadFactor, true) { + override def removeEldestEntry( + entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer[_, _]]): Boolean = { + + // Try to remove the least-used entry if its currently not in use. + // + // If you cannot remove it, then the cache will keep growing. In the worst case, + // the cache will grow to the max number of concurrent tasks that can run in the executor, + // (that is, number of tasks slots) after which it will never reduce. This is unlikely to + // be a serious problem because an executor with more than 64 (default) tasks slots is + // likely running on a beefy machine that can handle a large number of simultaneously + // active consumers. + + if (entry.getValue.inUse == false && this.size > maxCapacity) { + logWarning( + s"KafkaConsumer cache hitting max capacity of $maxCapacity, " + + s"removing consumer for ${entry.getKey}") + try { + entry.getValue.close() + } catch { + case x: KafkaException => + logError("Error closing oldest Kafka consumer", x) + } + true + } else { + false + } + } + } + } + } + + /** + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. + * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]]. + * + * Note: This method guarantees that the consumer returned is not currently in use by anyone + * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by + * caching them and tracking when they are in use. + */ + def acquire[K, V]( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object], + context: TaskContext, + useCache: Boolean): KafkaDataConsumer[K, V] = synchronized { + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + val key = new CacheKey(groupId, topicPartition) + val existingInternalConsumer = cache.get(key) + + lazy val newInternalConsumer = new InternalKafkaConsumer[K, V](topicPartition, kafkaParams) + + if (context != null && context.attemptNumber >= 1) { + // If this is reattempt at running the task, then invalidate cached consumers if any and + // start with a new one. If prior attempt failures were cache related then this way old + // problematic consumers can be removed. + logDebug(s"Reattempt detected, invalidating cached consumer $existingInternalConsumer") + if (existingInternalConsumer != null) { + // Consumer exists in cache. If its in use, mark it for closing later, or close it now. + if (existingInternalConsumer.inUse) { + existingInternalConsumer.markedForClose = true + } else { + existingInternalConsumer.close() + // Remove the consumer from cache only if it's closed. + // Marked for close consumers will be removed in release function. + cache.remove(key) + } + } + + logDebug("Reattempt detected, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else if (!useCache) { + // If consumer reuse turned off, then do not use it, return a new consumer + logDebug("Cache usage turned off, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else if (existingInternalConsumer == null) { + // If consumer is not already cached, then put a new in the cache and return it + logDebug("No cached consumer, new cached consumer will be allocated " + + s"$newInternalConsumer") + cache.put(key, newInternalConsumer) + CachedKafkaDataConsumer(newInternalConsumer) + } else if (existingInternalConsumer.inUse) { + // If consumer is already cached but is currently in use, then return a new consumer + logDebug("Used cached consumer found, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else { + // If consumer is already cached and is currently not in use, then return that consumer + logDebug(s"Not used cached consumer found, re-using it $existingInternalConsumer") + existingInternalConsumer.inUse = true + // Any given TopicPartition should have a consistent key and value type + CachedKafkaDataConsumer(existingInternalConsumer.asInstanceOf[InternalKafkaConsumer[K, V]]) + } + } + + private def release(internalConsumer: InternalKafkaConsumer[_, _]): Unit = synchronized { + // Clear the consumer from the cache if this is indeed the consumer present in the cache + val key = new CacheKey(internalConsumer.groupId, internalConsumer.topicPartition) + val cachedInternalConsumer = cache.get(key) + if (internalConsumer.eq(cachedInternalConsumer)) { + // The released consumer is the same object as the cached one. + if (internalConsumer.markedForClose) { + internalConsumer.close() + cache.remove(key) + } else { + internalConsumer.inUse = false + } + } else { + // The released consumer is either not the same one as in the cache, or not in the cache + // at all. This may happen if the cache was invalidate while this consumer was being used. + // Just close this consumer. + internalConsumer.close() + logInfo(s"Released a supposedly cached consumer that was not found in the cache " + + s"$internalConsumer") + } + } +} + +private[kafka010] object InternalKafkaConsumer { + private val UNKNOWN_OFFSET = -2L +} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 07239eda64d2e..3efc90fe466b2 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -19,8 +19,6 @@ package org.apache.spark.streaming.kafka010 import java.{ util => ju } -import scala.collection.mutable.ArrayBuffer - import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord } import org.apache.kafka.common.TopicPartition @@ -67,7 +65,7 @@ private[spark] class KafkaRDD[K, V]( // TODO is it necessary to have separate configs for initial poll time vs ongoing poll time? private val pollTimeout = conf.getLong("spark.streaming.kafka.consumer.poll.ms", - conf.getTimeAsMs("spark.network.timeout", "120s")) + conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L) private val cacheInitialCapacity = conf.getInt("spark.streaming.kafka.consumer.cache.initialCapacity", 16) private val cacheMaxCapacity = @@ -239,26 +237,18 @@ private class KafkaRDDIterator[K, V]( cacheLoadFactor: Float ) extends Iterator[ConsumerRecord[K, V]] { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - context.addTaskCompletionListener(_ => closeIfNeeded()) - val consumer = if (useConsumerCache) { - CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) - if (context.attemptNumber >= 1) { - // just in case the prior attempt failures were cache related - CachedKafkaConsumer.remove(groupId, part.topic, part.partition) - } - CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) - } else { - CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + val consumer = { + KafkaDataConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) + KafkaDataConsumer.acquire[K, V](part.topicPartition(), kafkaParams, context, useConsumerCache) } var requestOffset = part.fromOffset def closeIfNeeded(): Unit = { - if (!useConsumerCache && consumer != null) { - consumer.close() + if (consumer != null) { + consumer.release() } } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala new file mode 100644 index 0000000000000..d934c64962adb --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.util.concurrent.{Executors, TimeUnit} + +import scala.collection.JavaConverters._ +import scala.util.Random + +import org.apache.kafka.clients.consumer.ConsumerConfig._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark._ + +class KafkaDataConsumerSuite extends SparkFunSuite with BeforeAndAfterAll { + private var testUtils: KafkaTestUtils = _ + private val topic = "topic" + Random.nextInt() + private val topicPartition = new TopicPartition(topic, 0) + private val groupId = "groupId" + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils + testUtils.setup() + KafkaDataConsumer.init(16, 64, 0.75f) + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + private def getKafkaParams() = Map[String, Object]( + GROUP_ID_CONFIG -> groupId, + BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress, + KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + AUTO_OFFSET_RESET_CONFIG -> "earliest", + ENABLE_AUTO_COMMIT_CONFIG -> "false" + ).asJava + + test("KafkaDataConsumer reuse in case of same groupId and TopicPartition") { + KafkaDataConsumer.cache.clear() + + val kafkaParams = getKafkaParams() + + val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, null, true) + consumer1.release() + + val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, null, true) + consumer2.release() + + assert(KafkaDataConsumer.cache.size() == 1) + val key = new CacheKey(groupId, topicPartition) + val existingInternalConsumer = KafkaDataConsumer.cache.get(key) + assert(existingInternalConsumer.eq(consumer1.internalConsumer)) + assert(existingInternalConsumer.eq(consumer2.internalConsumer)) + } + + test("concurrent use of KafkaDataConsumer") { + val data = (1 to 1000).map(_.toString) + testUtils.createTopic(topic) + testUtils.sendMessages(topic, data.toArray) + + val kafkaParams = getKafkaParams() + + val numThreads = 100 + val numConsumerUsages = 500 + + @volatile var error: Throwable = null + + def consume(i: Int): Unit = { + val useCache = Random.nextBoolean + val taskContext = if (Random.nextBoolean) { + new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null) + } else { + null + } + val consumer = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, taskContext, useCache) + try { + val rcvd = (0 until data.length).map { offset => + val bytes = consumer.get(offset, 10000).value() + new String(bytes) + } + assert(rcvd == data) + } catch { + case e: Throwable => + error = e + throw e + } finally { + consumer.release() + } + } + + val threadPool = Executors.newFixedThreadPool(numThreads) + try { + val futures = (1 to numConsumerUsages).map { i => + threadPool.submit(new Runnable { + override def run(): Unit = { consume(i) } + }) + } + futures.foreach(_.get(1, TimeUnit.MINUTES)) + assert(error == null) + } finally { + threadPool.shutdown() + } + } +} diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 5ea52b6ad36a0..791cf0efaf888 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -191,6 +191,7 @@ class KafkaRDD[ private def fetchBatch: Iterator[MessageAndOffset] = { val req = new FetchRequestBuilder() + .clientId(consumer.clientId) .addFetch(part.topic, part.partition, requestOffset, kc.config.fetchMessageMaxBytes) .build() val resp = consumer.fetch(req) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index fa0de6298a5f1..69c52365b1bf8 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -160,7 +160,6 @@ private[kinesis] class KinesisReceiver[T]( cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider), workerId) .withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPosition.getPosition) .withTaskBackoffTimeMillis(500) .withRegionName(regionName) @@ -169,7 +168,8 @@ private[kinesis] class KinesisReceiver[T]( initialPosition match { case ts: AtTimestamp => baseClientLibConfiguration.withTimestampAtInitialPositionInStream(ts.getTimestamp) - case _ => baseClientLibConfiguration + case _ => + baseClientLibConfiguration.withInitialPositionInStream(initialPosition.getPosition) } } diff --git a/graphx/pom.xml b/graphx/pom.xml index fbe77fcb958d5..0f5dc548600b2 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -53,7 +53,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded com.google.guava diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index d76e84ed8c9ed..a559685b1633c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.HashSet import scala.language.existentials -import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor} -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6.{ClassReader, ClassVisitor, MethodVisitor} +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.util.Utils diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 8e424b1c50236..2c39a7df0146e 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -38,7 +38,32 @@ hadoop-cloud + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.hadoop + hadoop-client + ${hadoop.version} + provided + + + + hadoop-3.1 + + + + org.apache.hadoop + hadoop-cloud-storage + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + org.codehaus.jackson + jackson-mapper-asl + + + com.fasterxml.jackson.core + jackson-core + + + com.google.guava + guava + + + + + + org.eclipse.jetty + jetty-util + ${hadoop.deps.scope} + + + org.eclipse.jetty + jetty-util-ajax + ${jetty.version} + ${hadoop.deps.scope} + + + + diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 1e34bb8c73279..d967aa39a4827 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -54,10 +55,12 @@ public static void main(String[] argsArray) throws Exception { String className = args.remove(0); boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); - AbstractCommandBuilder builder; + Map env = new HashMap<>(); + List cmd; if (className.equals("org.apache.spark.deploy.SparkSubmit")) { try { - builder = new SparkSubmitCommandBuilder(args); + AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(args); + cmd = buildCommand(builder, env, printLaunchCommand); } catch (IllegalArgumentException e) { printLaunchCommand = false; System.err.println("Error: " + e.getMessage()); @@ -76,17 +79,12 @@ public static void main(String[] argsArray) throws Exception { help.add(parser.className); } help.add(parser.USAGE_ERROR); - builder = new SparkSubmitCommandBuilder(help); + AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(help); + cmd = buildCommand(builder, env, printLaunchCommand); } } else { - builder = new SparkClassCommandBuilder(className, args); - } - - Map env = new HashMap<>(); - List cmd = builder.buildCommand(env); - if (printLaunchCommand) { - System.err.println("Spark Command: " + join(" ", cmd)); - System.err.println("========================================"); + AbstractCommandBuilder builder = new SparkClassCommandBuilder(className, args); + cmd = buildCommand(builder, env, printLaunchCommand); } if (isWindows()) { @@ -101,6 +99,22 @@ public static void main(String[] argsArray) throws Exception { } } + /** + * Prepare spark commands with the appropriate command builder. + * If printLaunchCommand is set then the commands will be printed to the stderr. + */ + private static List buildCommand( + AbstractCommandBuilder builder, + Map env, + boolean printLaunchCommand) throws IOException, IllegalArgumentException { + List cmd = builder.buildCommand(env); + if (printLaunchCommand) { + System.err.println("Spark Command: " + join(" ", cmd)); + System.err.println("========================================"); + } + return cmd; + } + /** * Prepare a command line for execution from a Windows batch script. * diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 5cb6457bf5c21..cc65f78b45c30 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -90,7 +90,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { final List userArgs; private final List parsedArgs; - private final boolean requiresAppResource; + // Special command means no appResource and no mainClass required + private final boolean isSpecialCommand; private final boolean isExample; /** @@ -105,7 +106,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { * spark-submit argument list to be modified after creation. */ SparkSubmitCommandBuilder() { - this.requiresAppResource = true; + this.isSpecialCommand = false; this.isExample = false; this.parsedArgs = new ArrayList<>(); this.userArgs = new ArrayList<>(); @@ -138,25 +139,26 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { case RUN_EXAMPLE: isExample = true; + appResource = SparkLauncher.NO_RESOURCE; submitArgs = args.subList(1, args.size()); } this.isExample = isExample; OptionParser parser = new OptionParser(true); parser.parse(submitArgs); - this.requiresAppResource = parser.requiresAppResource; + this.isSpecialCommand = parser.isSpecialCommand; } else { this.isExample = isExample; - this.requiresAppResource = false; + this.isSpecialCommand = true; } } @Override public List buildCommand(Map env) throws IOException, IllegalArgumentException { - if (PYSPARK_SHELL.equals(appResource) && requiresAppResource) { + if (PYSPARK_SHELL.equals(appResource) && !isSpecialCommand) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL.equals(appResource) && requiresAppResource) { + } else if (SPARKR_SHELL.equals(appResource) && !isSpecialCommand) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -166,18 +168,18 @@ public List buildCommand(Map env) List buildSparkSubmitArgs() { List args = new ArrayList<>(); OptionParser parser = new OptionParser(false); - final boolean requiresAppResource; + final boolean isSpecialCommand; // If the user args array is not empty, we need to parse it to detect exactly what // the user is trying to run, so that checks below are correct. if (!userArgs.isEmpty()) { parser.parse(userArgs); - requiresAppResource = parser.requiresAppResource; + isSpecialCommand = parser.isSpecialCommand; } else { - requiresAppResource = this.requiresAppResource; + isSpecialCommand = this.isSpecialCommand; } - if (!allowsMixedArguments && requiresAppResource) { + if (!allowsMixedArguments && !isSpecialCommand) { checkArgument(appResource != null, "Missing application resource."); } @@ -229,7 +231,7 @@ List buildSparkSubmitArgs() { args.add(join(",", pyFiles)); } - if (isExample) { + if (isExample && !isSpecialCommand) { checkArgument(mainClass != null, "Missing example class name."); } @@ -421,7 +423,7 @@ private List findExamplesJars() { private class OptionParser extends SparkSubmitOptionParser { - boolean requiresAppResource = true; + boolean isSpecialCommand = false; private final boolean errorOnUnknownArgs; OptionParser(boolean errorOnUnknownArgs) { @@ -470,17 +472,14 @@ protected boolean handle(String opt, String value) { break; case KILL_SUBMISSION: case STATUS: - requiresAppResource = false; + isSpecialCommand = true; parsedArgs.add(opt); parsedArgs.add(value); break; case HELP: case USAGE_ERROR: - requiresAppResource = false; - parsedArgs.add(opt); - break; case VERSION: - requiresAppResource = false; + isSpecialCommand = true; parsedArgs.add(opt); break; default: diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 2e050f8413074..b343094b2e7b8 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.launcher; import java.io.File; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -27,7 +28,10 @@ import org.junit.AfterClass; import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; + import static org.junit.Assert.*; public class SparkSubmitCommandBuilderSuite extends BaseSuite { @@ -35,6 +39,9 @@ public class SparkSubmitCommandBuilderSuite extends BaseSuite { private static File dummyPropsFile; private static SparkSubmitOptionParser parser; + @Rule + public ExpectedException expectedException = ExpectedException.none(); + @BeforeClass public static void setUp() throws Exception { dummyPropsFile = File.createTempFile("spark", "properties"); @@ -74,8 +81,11 @@ public void testCliHelpAndNoArg() throws Exception { @Test public void testCliKillAndStatus() throws Exception { - testCLIOpts(parser.STATUS); - testCLIOpts(parser.KILL_SUBMISSION); + List params = Arrays.asList("driver-20160531171222-0000"); + testCLIOpts(null, parser.STATUS, params); + testCLIOpts(null, parser.KILL_SUBMISSION, params); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.STATUS, params); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.KILL_SUBMISSION, params); } @Test @@ -190,6 +200,33 @@ public void testSparkRShell() throws Exception { env.get("SPARKR_SUBMIT_ARGS")); } + @Test(expected = IllegalArgumentException.class) + public void testExamplesRunnerNoArg() throws Exception { + List sparkSubmitArgs = Arrays.asList(SparkSubmitCommandBuilder.RUN_EXAMPLE); + Map env = new HashMap<>(); + buildCommand(sparkSubmitArgs, env); + } + + @Test + public void testExamplesRunnerNoMainClass() throws Exception { + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.HELP, null); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.USAGE_ERROR, null); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.VERSION, null); + } + + @Test + public void testExamplesRunnerWithMasterNoMainClass() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing example class name."); + + List sparkSubmitArgs = Arrays.asList( + SparkSubmitCommandBuilder.RUN_EXAMPLE, + parser.MASTER + "=foo" + ); + Map env = new HashMap<>(); + buildCommand(sparkSubmitArgs, env); + } + @Test public void testExamplesRunner() throws Exception { List sparkSubmitArgs = Arrays.asList( @@ -344,10 +381,17 @@ private List buildCommand(List args, Map env) th return newCommandBuilder(args).buildCommand(env); } - private void testCLIOpts(String opt) throws Exception { - List helpArgs = Arrays.asList(opt, "driver-20160531171222-0000"); + private void testCLIOpts(String appResource, String opt, List params) throws Exception { + List args = new ArrayList<>(); + if (appResource != null) { + args.add(appResource); + } + args.add(opt); + if (params != null) { + args.addAll(params); + } Map env = new HashMap<>(); - List cmd = buildCommand(helpArgs, env); + List cmd = buildCommand(args, env); assertTrue(opt + " should be contained in the final cmd.", cmd.contains(opt)); } diff --git a/licenses/LICENSE-scopt.txt b/licenses-binary/LICENSE-AnchorJS.txt similarity index 100% rename from licenses/LICENSE-scopt.txt rename to licenses-binary/LICENSE-AnchorJS.txt diff --git a/licenses-binary/LICENSE-CC0.txt b/licenses-binary/LICENSE-CC0.txt new file mode 100644 index 0000000000000..1625c17936079 --- /dev/null +++ b/licenses-binary/LICENSE-CC0.txt @@ -0,0 +1,121 @@ +Creative Commons Legal Code + +CC0 1.0 Universal + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE + LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN + ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS + INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES + REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS + PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM + THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED + HEREUNDER. + +Statement of Purpose + +The laws of most jurisdictions throughout the world automatically confer +exclusive Copyright and Related Rights (defined below) upon the creator +and subsequent owner(s) (each and all, an "owner") of an original work of +authorship and/or a database (each, a "Work"). + +Certain owners wish to permanently relinquish those rights to a Work for +the purpose of contributing to a commons of creative, cultural and +scientific works ("Commons") that the public can reliably and without fear +of later claims of infringement build upon, modify, incorporate in other +works, reuse and redistribute as freely as possible in any form whatsoever +and for any purposes, including without limitation commercial purposes. +These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. In addition, to the +extent the Waiver is so judged Affirmer hereby grants to each affected +person a royalty-free, non transferable, non sublicensable, non exclusive, +irrevocable and unconditional license to exercise Affirmer's Copyright and +Related Rights in the Work (i) in all territories worldwide, (ii) for the +maximum duration provided by applicable law or treaty (including future +time extensions), (iii) in any current or future medium and for any number +of copies, and (iv) for any purpose whatsoever, including without +limitation commercial, advertising or promotional purposes (the +"License"). The License shall be deemed effective as of the date CC0 was +applied by Affirmer to the Work. Should any part of the License for any +reason be judged legally invalid or ineffective under applicable law, such +partial invalidity or ineffectiveness shall not invalidate the remainder +of the License, and in such case Affirmer hereby affirms that he or she +will not (i) exercise any of his or her remaining Copyright and Related +Rights in the Work or (ii) assert any associated claims and causes of +action with respect to the Work, in either case contrary to Affirmer's +express Statement of Purpose. + +4. Limitations and Disclaimers. + + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, + statutory or otherwise, including without limitation warranties of + title, merchantability, fitness for a particular purpose, non + infringement, or the absence of latent or other defects, accuracy, or + the present or absence of errors, whether or not discoverable, all to + the greatest extent permissible under applicable law. + c. Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without + limitation any person's Copyright and Related Rights in the Work. + Further, Affirmer disclaims responsibility for obtaining any necessary + consents, permissions or other rights required for any use of the + Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to + this CC0 or use of the Work. \ No newline at end of file diff --git a/licenses/LICENSE-antlr.txt b/licenses-binary/LICENSE-antlr.txt similarity index 100% rename from licenses/LICENSE-antlr.txt rename to licenses-binary/LICENSE-antlr.txt diff --git a/licenses-binary/LICENSE-arpack.txt b/licenses-binary/LICENSE-arpack.txt new file mode 100644 index 0000000000000..a3ad80087bb63 --- /dev/null +++ b/licenses-binary/LICENSE-arpack.txt @@ -0,0 +1,8 @@ +Copyright © 2018 The University of Tennessee. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +· Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +· Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer listed in this license in the documentation and/or other materials provided with the distribution. +· Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +This software is provided by the copyright holders and contributors "as is" and any express or implied warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose are disclaimed. in no event shall the copyright owner or contributors be liable for any direct, indirect, incidental, special, exemplary, or consequential damages (including, but not limited to, procurement of substitute goods or services; loss of use, data, or profits; or business interruption) however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence or otherwise) arising in any way out of the use of this software, even if advised of the possibility of such damage. \ No newline at end of file diff --git a/licenses-binary/LICENSE-automaton.txt b/licenses-binary/LICENSE-automaton.txt new file mode 100644 index 0000000000000..2fc6e8c3432f0 --- /dev/null +++ b/licenses-binary/LICENSE-automaton.txt @@ -0,0 +1,24 @@ +Copyright (c) 2001-2017 Anders Moeller +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. +3. The name of the author may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR +IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES +OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, +INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT +NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF +THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-bootstrap.txt b/licenses-binary/LICENSE-bootstrap.txt new file mode 100644 index 0000000000000..6c711832fbc85 --- /dev/null +++ b/licenses-binary/LICENSE-bootstrap.txt @@ -0,0 +1,13 @@ +Copyright 2013 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/licenses-binary/LICENSE-bouncycastle-bcprov.txt b/licenses-binary/LICENSE-bouncycastle-bcprov.txt new file mode 100644 index 0000000000000..c445a93a06dd4 --- /dev/null +++ b/licenses-binary/LICENSE-bouncycastle-bcprov.txt @@ -0,0 +1,7 @@ +Copyright (c) 2000 - 2018 The Legion of the Bouncy Castle Inc. (https://www.bouncycastle.org) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-cloudpickle.txt b/licenses-binary/LICENSE-cloudpickle.txt new file mode 100644 index 0000000000000..b1e20fa1eda88 --- /dev/null +++ b/licenses-binary/LICENSE-cloudpickle.txt @@ -0,0 +1,28 @@ +Copyright (c) 2012, Regents of the University of California. +Copyright (c) 2009 `PiCloud, Inc. `_. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the University of California, Berkeley nor the + names of its contributors may be used to endorse or promote + products derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-d3.min.js.txt b/licenses-binary/LICENSE-d3.min.js.txt new file mode 100644 index 0000000000000..c71e3f254c068 --- /dev/null +++ b/licenses-binary/LICENSE-d3.min.js.txt @@ -0,0 +1,26 @@ +Copyright (c) 2010-2015, Michael Bostock +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* The name Michael Bostock may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-Mockito.txt b/licenses-binary/LICENSE-dagre-d3.txt similarity index 94% rename from licenses/LICENSE-Mockito.txt rename to licenses-binary/LICENSE-dagre-d3.txt index e0840a446caf5..4864fe05e9803 100644 --- a/licenses/LICENSE-Mockito.txt +++ b/licenses-binary/LICENSE-dagre-d3.txt @@ -1,6 +1,4 @@ -The MIT License - -Copyright (c) 2007 Mockito contributors +Copyright (c) 2013 Chris Pettitt Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/licenses-binary/LICENSE-datatables.txt b/licenses-binary/LICENSE-datatables.txt new file mode 100644 index 0000000000000..bb7708b5b5a49 --- /dev/null +++ b/licenses-binary/LICENSE-datatables.txt @@ -0,0 +1,7 @@ +Copyright (C) 2008-2018, SpryMedia Ltd. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-f2j.txt b/licenses-binary/LICENSE-f2j.txt similarity index 100% rename from licenses/LICENSE-f2j.txt rename to licenses-binary/LICENSE-f2j.txt diff --git a/licenses-binary/LICENSE-graphlib-dot.txt b/licenses-binary/LICENSE-graphlib-dot.txt new file mode 100644 index 0000000000000..4864fe05e9803 --- /dev/null +++ b/licenses-binary/LICENSE-graphlib-dot.txt @@ -0,0 +1,19 @@ +Copyright (c) 2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-heapq.txt b/licenses-binary/LICENSE-heapq.txt new file mode 100644 index 0000000000000..0c4c4b954bea4 --- /dev/null +++ b/licenses-binary/LICENSE-heapq.txt @@ -0,0 +1,280 @@ + +# A. HISTORY OF THE SOFTWARE +# ========================== +# +# Python was created in the early 1990s by Guido van Rossum at Stichting +# Mathematisch Centrum (CWI, see http://www.cwi.nl) in the Netherlands +# as a successor of a language called ABC. Guido remains Python's +# principal author, although it includes many contributions from others. +# +# In 1995, Guido continued his work on Python at the Corporation for +# National Research Initiatives (CNRI, see http://www.cnri.reston.va.us) +# in Reston, Virginia where he released several versions of the +# software. +# +# In May 2000, Guido and the Python core development team moved to +# BeOpen.com to form the BeOpen PythonLabs team. In October of the same +# year, the PythonLabs team moved to Digital Creations (now Zope +# Corporation, see http://www.zope.com). In 2001, the Python Software +# Foundation (PSF, see http://www.python.org/psf/) was formed, a +# non-profit organization created specifically to own Python-related +# Intellectual Property. Zope Corporation is a sponsoring member of +# the PSF. +# +# All Python releases are Open Source (see http://www.opensource.org for +# the Open Source Definition). Historically, most, but not all, Python +# releases have also been GPL-compatible; the table below summarizes +# the various releases. +# +# Release Derived Year Owner GPL- +# from compatible? (1) +# +# 0.9.0 thru 1.2 1991-1995 CWI yes +# 1.3 thru 1.5.2 1.2 1995-1999 CNRI yes +# 1.6 1.5.2 2000 CNRI no +# 2.0 1.6 2000 BeOpen.com no +# 1.6.1 1.6 2001 CNRI yes (2) +# 2.1 2.0+1.6.1 2001 PSF no +# 2.0.1 2.0+1.6.1 2001 PSF yes +# 2.1.1 2.1+2.0.1 2001 PSF yes +# 2.2 2.1.1 2001 PSF yes +# 2.1.2 2.1.1 2002 PSF yes +# 2.1.3 2.1.2 2002 PSF yes +# 2.2.1 2.2 2002 PSF yes +# 2.2.2 2.2.1 2002 PSF yes +# 2.2.3 2.2.2 2003 PSF yes +# 2.3 2.2.2 2002-2003 PSF yes +# 2.3.1 2.3 2002-2003 PSF yes +# 2.3.2 2.3.1 2002-2003 PSF yes +# 2.3.3 2.3.2 2002-2003 PSF yes +# 2.3.4 2.3.3 2004 PSF yes +# 2.3.5 2.3.4 2005 PSF yes +# 2.4 2.3 2004 PSF yes +# 2.4.1 2.4 2005 PSF yes +# 2.4.2 2.4.1 2005 PSF yes +# 2.4.3 2.4.2 2006 PSF yes +# 2.4.4 2.4.3 2006 PSF yes +# 2.5 2.4 2006 PSF yes +# 2.5.1 2.5 2007 PSF yes +# 2.5.2 2.5.1 2008 PSF yes +# 2.5.3 2.5.2 2008 PSF yes +# 2.6 2.5 2008 PSF yes +# 2.6.1 2.6 2008 PSF yes +# 2.6.2 2.6.1 2009 PSF yes +# 2.6.3 2.6.2 2009 PSF yes +# 2.6.4 2.6.3 2009 PSF yes +# 2.6.5 2.6.4 2010 PSF yes +# 2.7 2.6 2010 PSF yes +# +# Footnotes: +# +# (1) GPL-compatible doesn't mean that we're distributing Python under +# the GPL. All Python licenses, unlike the GPL, let you distribute +# a modified version without making your changes open source. The +# GPL-compatible licenses make it possible to combine Python with +# other software that is released under the GPL; the others don't. +# +# (2) According to Richard Stallman, 1.6.1 is not GPL-compatible, +# because its license has a choice of law clause. According to +# CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 +# is "not incompatible" with the GPL. +# +# Thanks to the many outside volunteers who have worked under Guido's +# direction to make these releases possible. +# +# +# B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON +# =============================================================== +# +# PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# -------------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013 Python Software Foundation; All Rights Reserved" are retained +# in Python alone or in any derivative version prepared by Licensee. +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 +# ------------------------------------------- +# +# BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 +# +# 1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an +# office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the +# Individual or Organization ("Licensee") accessing and otherwise using +# this software in source or binary form and its associated +# documentation ("the Software"). +# +# 2. Subject to the terms and conditions of this BeOpen Python License +# Agreement, BeOpen hereby grants Licensee a non-exclusive, +# royalty-free, world-wide license to reproduce, analyze, test, perform +# and/or display publicly, prepare derivative works, distribute, and +# otherwise use the Software alone or in any derivative version, +# provided, however, that the BeOpen Python License is retained in the +# Software, alone or in any derivative version prepared by Licensee. +# +# 3. BeOpen is making the Software available to Licensee on an "AS IS" +# basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE +# SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS +# AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY +# DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 5. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 6. This License Agreement shall be governed by and interpreted in all +# respects by the law of the State of California, excluding conflict of +# law provisions. Nothing in this License Agreement shall be deemed to +# create any relationship of agency, partnership, or joint venture +# between BeOpen and Licensee. This License Agreement does not grant +# permission to use BeOpen trademarks or trade names in a trademark +# sense to endorse or promote products or services of Licensee, or any +# third party. As an exception, the "BeOpen Python" logos available at +# http://www.pythonlabs.com/logos.html may be used according to the +# permissions granted on that web page. +# +# 7. By copying, installing or otherwise using the software, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 +# --------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Corporation for National +# Research Initiatives, having an office at 1895 Preston White Drive, +# Reston, VA 20191 ("CNRI"), and the Individual or Organization +# ("Licensee") accessing and otherwise using Python 1.6.1 software in +# source or binary form and its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, CNRI +# hereby grants Licensee a nonexclusive, royalty-free, world-wide +# license to reproduce, analyze, test, perform and/or display publicly, +# prepare derivative works, distribute, and otherwise use Python 1.6.1 +# alone or in any derivative version, provided, however, that CNRI's +# License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) +# 1995-2001 Corporation for National Research Initiatives; All Rights +# Reserved" are retained in Python 1.6.1 alone or in any derivative +# version prepared by Licensee. Alternately, in lieu of CNRI's License +# Agreement, Licensee may substitute the following text (omitting the +# quotes): "Python 1.6.1 is made available subject to the terms and +# conditions in CNRI's License Agreement. This Agreement together with +# Python 1.6.1 may be located on the Internet using the following +# unique, persistent identifier (known as a handle): 1895.22/1013. This +# Agreement may also be obtained from a proxy server on the Internet +# using the following URL: http://hdl.handle.net/1895.22/1013". +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python 1.6.1 or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python 1.6.1. +# +# 4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" +# basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. This License Agreement shall be governed by the federal +# intellectual property law of the United States, including without +# limitation the federal copyright law, and, to the extent such +# U.S. federal law does not apply, by the law of the Commonwealth of +# Virginia, excluding Virginia's conflict of law provisions. +# Notwithstanding the foregoing, with regard to derivative works based +# on Python 1.6.1 that incorporate non-separable material that was +# previously distributed under the GNU General Public License (GPL), the +# law of the Commonwealth of Virginia shall govern this License +# Agreement only as to issues arising under or with respect to +# Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this +# License Agreement shall be deemed to create any relationship of +# agency, partnership, or joint venture between CNRI and Licensee. This +# License Agreement does not grant permission to use CNRI trademarks or +# trade name in a trademark sense to endorse or promote products or +# services of Licensee, or any third party. +# +# 8. By clicking on the "ACCEPT" button where indicated, or by copying, +# installing or otherwise using Python 1.6.1, Licensee agrees to be +# bound by the terms and conditions of this License Agreement. +# +# ACCEPT +# +# +# CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 +# -------------------------------------------------- +# +# Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, +# The Netherlands. All rights reserved. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose and without fee is hereby granted, +# provided that the above copyright notice appear in all copies and that +# both that copyright notice and this permission notice appear in +# supporting documentation, and that the name of Stichting Mathematisch +# Centrum or CWI not be used in advertising or publicity pertaining to +# distribution of the software without specific, written prior +# permission. +# +# STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO +# THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE +# FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-janino.txt b/licenses-binary/LICENSE-janino.txt new file mode 100644 index 0000000000000..d1e1f237c4641 --- /dev/null +++ b/licenses-binary/LICENSE-janino.txt @@ -0,0 +1,31 @@ +Janino - An embedded Java[TM] compiler + +Copyright (c) 2001-2016, Arno Unkrig +Copyright (c) 2015-2016 TIBCO Software Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials + provided with the distribution. + 3. Neither the name of JANINO nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER +IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN +IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-javassist.html b/licenses-binary/LICENSE-javassist.html new file mode 100644 index 0000000000000..5abd563a0c4d9 --- /dev/null +++ b/licenses-binary/LICENSE-javassist.html @@ -0,0 +1,373 @@ + + + Javassist License + + + + +
    MOZILLA PUBLIC LICENSE
    Version + 1.1 +

    +


    +
    +

    1. Definitions. +

      1.0.1. "Commercial Use" means distribution or otherwise making the + Covered Code available to a third party. +

      1.1. ''Contributor'' means each entity that creates or contributes + to the creation of Modifications. +

      1.2. ''Contributor Version'' means the combination of the Original + Code, prior Modifications used by a Contributor, and the Modifications made by + that particular Contributor. +

      1.3. ''Covered Code'' means the Original Code or Modifications or + the combination of the Original Code and Modifications, in each case including + portions thereof. +

      1.4. ''Electronic Distribution Mechanism'' means a mechanism + generally accepted in the software development community for the electronic + transfer of data. +

      1.5. ''Executable'' means Covered Code in any form other than Source + Code. +

      1.6. ''Initial Developer'' means the individual or entity identified + as the Initial Developer in the Source Code notice required by Exhibit + A. +

      1.7. ''Larger Work'' means a work which combines Covered Code or + portions thereof with code not governed by the terms of this License. +

      1.8. ''License'' means this document. +

      1.8.1. "Licensable" means having the right to grant, to the maximum + extent possible, whether at the time of the initial grant or subsequently + acquired, any and all of the rights conveyed herein. +

      1.9. ''Modifications'' means any addition to or deletion from the + substance or structure of either the Original Code or any previous + Modifications. When Covered Code is released as a series of files, a + Modification is: +

        A. Any addition to or deletion from the contents of a file + containing Original Code or previous Modifications. +

        B. Any new file that contains any part of the Original Code or + previous Modifications.
         

      1.10. ''Original Code'' +means Source Code of computer software code which is described in the Source +Code notice required by Exhibit A as Original Code, and which, at the +time of its release under this License is not already Covered Code governed by +this License. +

      1.10.1. "Patent Claims" means any patent claim(s), now owned or + hereafter acquired, including without limitation,  method, process, and + apparatus claims, in any patent Licensable by grantor. +

      1.11. ''Source Code'' means the preferred form of the Covered Code + for making modifications to it, including all modules it contains, plus any + associated interface definition files, scripts used to control compilation and + installation of an Executable, or source code differential comparisons against + either the Original Code or another well known, available Covered Code of the + Contributor's choice. The Source Code can be in a compressed or archival form, + provided the appropriate decompression or de-archiving software is widely + available for no charge. +

      1.12. "You'' (or "Your")  means an individual or a legal entity + exercising rights under, and complying with all of the terms of, this License + or a future version of this License issued under Section 6.1. For legal + entities, "You'' includes any entity which controls, is controlled by, or is + under common control with You. For purposes of this definition, "control'' + means (a) the power, direct or indirect, to cause the direction or management + of such entity, whether by contract or otherwise, or (b) ownership of more + than fifty percent (50%) of the outstanding shares or beneficial ownership of + such entity.

    2. Source Code License. +
      2.1. The Initial Developer Grant.
      The Initial Developer hereby + grants You a world-wide, royalty-free, non-exclusive license, subject to third + party intellectual property claims: +
        (a)  under intellectual property rights (other than + patent or trademark) Licensable by Initial Developer to use, reproduce, + modify, display, perform, sublicense and distribute the Original Code (or + portions thereof) with or without Modifications, and/or as part of a Larger + Work; and +

        (b) under Patents Claims infringed by the making, using or selling + of Original Code, to make, have made, use, practice, sell, and offer for + sale, and/or otherwise dispose of the Original Code (or portions thereof). +

          +
          (c) the licenses granted in this Section 2.1(a) and (b) + are effective on the date Initial Developer first distributes Original Code + under the terms of this License. +

          (d) Notwithstanding Section 2.1(b) above, no patent license is + granted: 1) for code that You delete from the Original Code; 2) separate + from the Original Code;  or 3) for infringements caused by: i) the + modification of the Original Code or ii) the combination of the Original + Code with other software or devices.
           

        2.2. Contributor + Grant.
        Subject to third party intellectual property claims, each + Contributor hereby grants You a world-wide, royalty-free, non-exclusive + license +

          (a)  under intellectual property rights (other + than patent or trademark) Licensable by Contributor, to use, reproduce, + modify, display, perform, sublicense and distribute the Modifications + created by such Contributor (or portions thereof) either on an unmodified + basis, with other Modifications, as Covered Code and/or as part of a Larger + Work; and +

          (b) under Patent Claims infringed by the making, using, or selling + of  Modifications made by that Contributor either alone and/or in combination with its Contributor Version (or portions of such + combination), to make, use, sell, offer for sale, have made, and/or + otherwise dispose of: 1) Modifications made by that Contributor (or portions + thereof); and 2) the combination of  Modifications made by that + Contributor with its Contributor Version (or portions of such + combination). +

          (c) the licenses granted in Sections 2.2(a) and 2.2(b) are + effective on the date Contributor first makes Commercial Use of the Covered + Code. +

          (d)    Notwithstanding Section 2.2(b) above, no + patent license is granted: 1) for any code that Contributor has deleted from + the Contributor Version; 2)  separate from the Contributor + Version;  3)  for infringements caused by: i) third party + modifications of Contributor Version or ii)  the combination of + Modifications made by that Contributor with other software  (except as + part of the Contributor Version) or other devices; or 4) under Patent Claims + infringed by Covered Code in the absence of Modifications made by that + Contributor.

      +


      3. Distribution Obligations. +

        3.1. Application of License.
        The Modifications which You create + or to which You contribute are governed by the terms of this License, + including without limitation Section 2.2. The Source Code version of + Covered Code may be distributed only under the terms of this License or a + future version of this License released under Section 6.1, and You must + include a copy of this License with every copy of the Source Code You + distribute. You may not offer or impose any terms on any Source Code version + that alters or restricts the applicable version of this License or the + recipients' rights hereunder. However, You may include an additional document + offering the additional rights described in Section 3.5. +

        3.2. Availability of Source Code.
        Any Modification which You + create or to which You contribute must be made available in Source Code form + under the terms of this License either on the same media as an Executable + version or via an accepted Electronic Distribution Mechanism to anyone to whom + you made an Executable version available; and if made available via Electronic + Distribution Mechanism, must remain available for at least twelve (12) months + after the date it initially became available, or at least six (6) months after + a subsequent version of that particular Modification has been made available + to such recipients. You are responsible for ensuring that the Source Code + version remains available even if the Electronic Distribution Mechanism is + maintained by a third party. +

        3.3. Description of Modifications.
        You must cause all Covered + Code to which You contribute to contain a file documenting the changes You + made to create that Covered Code and the date of any change. You must include + a prominent statement that the Modification is derived, directly or + indirectly, from Original Code provided by the Initial Developer and including + the name of the Initial Developer in (a) the Source Code, and (b) in any + notice in an Executable version or related documentation in which You describe + the origin or ownership of the Covered Code. +

        3.4. Intellectual Property Matters +

          (a) Third Party Claims.
          If Contributor has knowledge that a + license under a third party's intellectual property rights is required to + exercise the rights granted by such Contributor under Sections 2.1 or 2.2, + Contributor must include a text file with the Source Code distribution + titled "LEGAL'' which describes the claim and the party making the claim in + sufficient detail that a recipient will know whom to contact. If Contributor + obtains such knowledge after the Modification is made available as described + in Section 3.2, Contributor shall promptly modify the LEGAL file in all + copies Contributor makes available thereafter and shall take other steps + (such as notifying appropriate mailing lists or newsgroups) reasonably + calculated to inform those who received the Covered Code that new knowledge + has been obtained. +

          (b) Contributor APIs.
          If Contributor's Modifications include + an application programming interface and Contributor has knowledge of patent + licenses which are reasonably necessary to implement that API, Contributor + must also include this information in the LEGAL file. +
           

                  +(c)    Representations. +
          Contributor represents that, except as disclosed pursuant to Section + 3.4(a) above, Contributor believes that Contributor's Modifications are + Contributor's original creation(s) and/or Contributor has sufficient rights + to grant the rights conveyed by this License.
        +


        3.5. Required Notices.
        You must duplicate the notice in + Exhibit A in each file of the Source Code.  If it is not possible + to put such notice in a particular Source Code file due to its structure, then + You must include such notice in a location (such as a relevant directory) + where a user would be likely to look for such a notice.  If You created + one or more Modification(s) You may add your name as a Contributor to the + notice described in Exhibit A.  You must also duplicate this + License in any documentation for the Source Code where You describe + recipients' rights or ownership rights relating to Covered Code.  You may + choose to offer, and to charge a fee for, warranty, support, indemnity or + liability obligations to one or more recipients of Covered Code. However, You + may do so only on Your own behalf, and not on behalf of the Initial Developer + or any Contributor. You must make it absolutely clear than any such warranty, + support, indemnity or liability obligation is offered by You alone, and You + hereby agree to indemnify the Initial Developer and every Contributor for any + liability incurred by the Initial Developer or such Contributor as a result of + warranty, support, indemnity or liability terms You offer. +

        3.6. Distribution of Executable Versions.
        You may distribute + Covered Code in Executable form only if the requirements of Section + 3.1-3.5 have been met for that Covered Code, and if You include a + notice stating that the Source Code version of the Covered Code is available + under the terms of this License, including a description of how and where You + have fulfilled the obligations of Section 3.2. The notice must be + conspicuously included in any notice in an Executable version, related + documentation or collateral in which You describe recipients' rights relating + to the Covered Code. You may distribute the Executable version of Covered Code + or ownership rights under a license of Your choice, which may contain terms + different from this License, provided that You are in compliance with the + terms of this License and that the license for the Executable version does not + attempt to limit or alter the recipient's rights in the Source Code version + from the rights set forth in this License. If You distribute the Executable + version under a different license You must make it absolutely clear that any + terms which differ from this License are offered by You alone, not by the + Initial Developer or any Contributor. You hereby agree to indemnify the + Initial Developer and every Contributor for any liability incurred by the + Initial Developer or such Contributor as a result of any such terms You offer. + +

        3.7. Larger Works.
        You may create a Larger Work by combining + Covered Code with other code not governed by the terms of this License and + distribute the Larger Work as a single product. In such a case, You must make + sure the requirements of this License are fulfilled for the Covered + Code.

      4. Inability to Comply Due to Statute or Regulation. +
        If it is impossible for You to comply with any of the terms of this + License with respect to some or all of the Covered Code due to statute, + judicial order, or regulation then You must: (a) comply with the terms of this + License to the maximum extent possible; and (b) describe the limitations and + the code they affect. Such description must be included in the LEGAL file + described in Section 3.4 and must be included with all distributions of + the Source Code. Except to the extent prohibited by statute or regulation, + such description must be sufficiently detailed for a recipient of ordinary + skill to be able to understand it.
      5. Application of this License. +
        This License applies to code to which the Initial Developer has attached + the notice in Exhibit A and to related Covered Code.
      6. Versions + of the License. +
        6.1. New Versions.
        Netscape Communications Corporation + (''Netscape'') may publish revised and/or new versions of the License from + time to time. Each version will be given a distinguishing version number. +

        6.2. Effect of New Versions.
        Once Covered Code has been + published under a particular version of the License, You may always continue + to use it under the terms of that version. You may also choose to use such + Covered Code under the terms of any subsequent version of the License + published by Netscape. No one other than Netscape has the right to modify the + terms applicable to Covered Code created under this License. +

        6.3. Derivative Works.
        If You create or use a modified version + of this License (which you may only do in order to apply it to code which is + not already Covered Code governed by this License), You must (a) rename Your + license so that the phrases ''Mozilla'', ''MOZILLAPL'', ''MOZPL'', + ''Netscape'', "MPL", ''NPL'' or any confusingly similar phrase do not appear + in your license (except to note that your license differs from this License) + and (b) otherwise make it clear that Your version of the license contains + terms which differ from the Mozilla Public License and Netscape Public + License. (Filling in the name of the Initial Developer, Original Code or + Contributor in the notice described in Exhibit A shall not of + themselves be deemed to be modifications of this License.)

      7. + DISCLAIMER OF WARRANTY. +
        COVERED CODE IS PROVIDED UNDER THIS LICENSE ON AN "AS IS'' BASIS, WITHOUT + WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, WITHOUT + LIMITATION, WARRANTIES THAT THE COVERED CODE IS FREE OF DEFECTS, MERCHANTABLE, + FIT FOR A PARTICULAR PURPOSE OR NON-INFRINGING. THE ENTIRE RISK AS TO THE + QUALITY AND PERFORMANCE OF THE COVERED CODE IS WITH YOU. SHOULD ANY COVERED + CODE PROVE DEFECTIVE IN ANY RESPECT, YOU (NOT THE INITIAL DEVELOPER OR ANY + OTHER CONTRIBUTOR) ASSUME THE COST OF ANY NECESSARY SERVICING, REPAIR OR + CORRECTION. THIS DISCLAIMER OF WARRANTY CONSTITUTES AN ESSENTIAL PART OF THIS + LICENSE. NO USE OF ANY COVERED CODE IS AUTHORIZED HEREUNDER EXCEPT UNDER THIS + DISCLAIMER.
      8. TERMINATION. +
        8.1.  This License and the rights granted hereunder will + terminate automatically if You fail to comply with terms herein and fail to + cure such breach within 30 days of becoming aware of the breach. All + sublicenses to the Covered Code which are properly granted shall survive any + termination of this License. Provisions which, by their nature, must remain in + effect beyond the termination of this License shall survive. +

        8.2.  If You initiate litigation by asserting a patent + infringement claim (excluding declatory judgment actions) against Initial + Developer or a Contributor (the Initial Developer or Contributor against whom + You file such action is referred to as "Participant")  alleging that: +

        (a)  such Participant's Contributor Version directly or + indirectly infringes any patent, then any and all rights granted by such + Participant to You under Sections 2.1 and/or 2.2 of this License shall, upon + 60 days notice from Participant terminate prospectively, unless if within 60 + days after receipt of notice You either: (i)  agree in writing to pay + Participant a mutually agreeable reasonable royalty for Your past and future + use of Modifications made by such Participant, or (ii) withdraw Your + litigation claim with respect to the Contributor Version against such + Participant.  If within 60 days of notice, a reasonable royalty and + payment arrangement are not mutually agreed upon in writing by the parties or + the litigation claim is not withdrawn, the rights granted by Participant to + You under Sections 2.1 and/or 2.2 automatically terminate at the expiration of + the 60 day notice period specified above. +

        (b)  any software, hardware, or device, other than such + Participant's Contributor Version, directly or indirectly infringes any + patent, then any rights granted to You by such Participant under Sections + 2.1(b) and 2.2(b) are revoked effective as of the date You first made, used, + sold, distributed, or had made, Modifications made by that Participant. +

        8.3.  If You assert a patent infringement claim against + Participant alleging that such Participant's Contributor Version directly or + indirectly infringes any patent where such claim is resolved (such as by + license or settlement) prior to the initiation of patent infringement + litigation, then the reasonable value of the licenses granted by such + Participant under Sections 2.1 or 2.2 shall be taken into account in + determining the amount or value of any payment or license. +

        8.4.  In the event of termination under Sections 8.1 or 8.2 + above,  all end user license agreements (excluding distributors and + resellers) which have been validly granted by You or any distributor hereunder + prior to termination shall survive termination.

      9. LIMITATION OF + LIABILITY. +
        UNDER NO CIRCUMSTANCES AND UNDER NO LEGAL THEORY, WHETHER TORT (INCLUDING + NEGLIGENCE), CONTRACT, OR OTHERWISE, SHALL YOU, THE INITIAL DEVELOPER, ANY + OTHER CONTRIBUTOR, OR ANY DISTRIBUTOR OF COVERED CODE, OR ANY SUPPLIER OF ANY + OF SUCH PARTIES, BE LIABLE TO ANY PERSON FOR ANY INDIRECT, SPECIAL, + INCIDENTAL, OR CONSEQUENTIAL DAMAGES OF ANY CHARACTER INCLUDING, WITHOUT + LIMITATION, DAMAGES FOR LOSS OF GOODWILL, WORK STOPPAGE, COMPUTER FAILURE OR + MALFUNCTION, OR ANY AND ALL OTHER COMMERCIAL DAMAGES OR LOSSES, EVEN IF SUCH + PARTY SHALL HAVE BEEN INFORMED OF THE POSSIBILITY OF SUCH DAMAGES. THIS + LIMITATION OF LIABILITY SHALL NOT APPLY TO LIABILITY FOR DEATH OR PERSONAL + INJURY RESULTING FROM SUCH PARTY'S NEGLIGENCE TO THE EXTENT APPLICABLE LAW + PROHIBITS SUCH LIMITATION. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OR + LIMITATION OF INCIDENTAL OR CONSEQUENTIAL DAMAGES, SO THIS EXCLUSION AND + LIMITATION MAY NOT APPLY TO YOU.
      10. U.S. GOVERNMENT END USERS. +
        The Covered Code is a ''commercial item,'' as that term is defined in 48 + C.F.R. 2.101 (Oct. 1995), consisting of ''commercial computer software'' and + ''commercial computer software documentation,'' as such terms are used in 48 + C.F.R. 12.212 (Sept. 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R. + 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users + acquire Covered Code with only those rights set forth herein.
      11. + MISCELLANEOUS. +
        This License represents the complete agreement concerning subject matter + hereof. If any provision of this License is held to be unenforceable, such + provision shall be reformed only to the extent necessary to make it + enforceable. This License shall be governed by California law provisions + (except to the extent applicable law, if any, provides otherwise), excluding + its conflict-of-law provisions. With respect to disputes in which at least one + party is a citizen of, or an entity chartered or registered to do business in + the United States of America, any litigation relating to this License shall be + subject to the jurisdiction of the Federal Courts of the Northern District of + California, with venue lying in Santa Clara County, California, with the + losing party responsible for costs, including without limitation, court costs + and reasonable attorneys' fees and expenses. The application of the United + Nations Convention on Contracts for the International Sale of Goods is + expressly excluded. Any law or regulation which provides that the language of + a contract shall be construed against the drafter shall not apply to this + License.
      12. RESPONSIBILITY FOR CLAIMS. +
        As between Initial Developer and the Contributors, each party is + responsible for claims and damages arising, directly or indirectly, out of its + utilization of rights under this License and You agree to work with Initial + Developer and Contributors to distribute such responsibility on an equitable + basis. Nothing herein is intended or shall be deemed to constitute any + admission of liability.
      13. MULTIPLE-LICENSED CODE. +
        Initial Developer may designate portions of the Covered Code as + "Multiple-Licensed".  "Multiple-Licensed" means that the Initial + Developer permits you to utilize portions of the Covered Code under Your + choice of the MPL or the alternative licenses, if any, specified by the + Initial Developer in the file described in Exhibit A.
      +


      EXHIBIT A -Mozilla Public License. +

        The contents of this file are subject to the Mozilla Public License + Version 1.1 (the "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at +
        http://www.mozilla.org/MPL/ +

        Software distributed under the License is distributed on an "AS IS" basis, + WITHOUT WARRANTY OF
        ANY KIND, either express or implied. See the License + for the specific language governing rights and
        limitations under the + License. +

        The Original Code is Javassist. +

        The Initial Developer of the Original Code is Shigeru Chiba. + Portions created by the Initial Developer are
          + Copyright (C) 1999- Shigeru Chiba. All Rights Reserved. +

        Contributor(s): __Bill Burke, Jason T. Greene______________. + +

        Alternatively, the contents of this software may be used under the + terms of the GNU Lesser General Public License Version 2.1 or later + (the "LGPL"), or the Apache License Version 2.0 (the "AL"), + in which case the provisions of the LGPL or the AL are applicable + instead of those above. If you wish to allow use of your version of + this software only under the terms of either the LGPL or the AL, and not to allow others to + use your version of this software under the terms of the MPL, indicate + your decision by deleting the provisions above and replace them with + the notice and other provisions required by the LGPL or the AL. If you do not + delete the provisions above, a recipient may use your version of this + software under the terms of any one of the MPL, the LGPL or the AL. + +

      + + \ No newline at end of file diff --git a/licenses/LICENSE-javolution.txt b/licenses-binary/LICENSE-javolution.txt similarity index 100% rename from licenses/LICENSE-javolution.txt rename to licenses-binary/LICENSE-javolution.txt diff --git a/licenses/LICENSE-jline.txt b/licenses-binary/LICENSE-jline.txt similarity index 100% rename from licenses/LICENSE-jline.txt rename to licenses-binary/LICENSE-jline.txt diff --git a/licenses/LICENSE-junit-interface.txt b/licenses-binary/LICENSE-jodd.txt similarity index 69% rename from licenses/LICENSE-junit-interface.txt rename to licenses-binary/LICENSE-jodd.txt index e835350c4e2a4..cc6b458adb386 100644 --- a/licenses/LICENSE-junit-interface.txt +++ b/licenses-binary/LICENSE-jodd.txt @@ -1,15 +1,15 @@ -Copyright (c) 2009-2012, Stefan Zeiger +Copyright (c) 2003-present, Jodd Team (https://jodd.org) All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. +1. Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE diff --git a/licenses/LICENSE-DPark.txt b/licenses-binary/LICENSE-join.txt similarity index 100% rename from licenses/LICENSE-DPark.txt rename to licenses-binary/LICENSE-join.txt diff --git a/licenses-binary/LICENSE-jquery.txt b/licenses-binary/LICENSE-jquery.txt new file mode 100644 index 0000000000000..45930542204fb --- /dev/null +++ b/licenses-binary/LICENSE-jquery.txt @@ -0,0 +1,20 @@ +Copyright JS Foundation and other contributors, https://js.foundation/ + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-json-formatter.txt b/licenses-binary/LICENSE-json-formatter.txt new file mode 100644 index 0000000000000..5193348fce126 --- /dev/null +++ b/licenses-binary/LICENSE-json-formatter.txt @@ -0,0 +1,6 @@ +Copyright 2014 Mohsen Azimi + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/licenses-binary/LICENSE-jtransforms.html b/licenses-binary/LICENSE-jtransforms.html new file mode 100644 index 0000000000000..351c17412357b --- /dev/null +++ b/licenses-binary/LICENSE-jtransforms.html @@ -0,0 +1,388 @@ + + +Mozilla Public License version 1.1 + + + + +

      Mozilla Public License Version 1.1

      +

      1. Definitions.

      +
      +
      1.0.1. "Commercial Use" +
      means distribution or otherwise making the Covered Code available to a third party. +
      1.1. "Contributor" +
      means each entity that creates or contributes to the creation of Modifications. +
      1.2. "Contributor Version" +
      means the combination of the Original Code, prior Modifications used by a Contributor, + and the Modifications made by that particular Contributor. +
      1.3. "Covered Code" +
      means the Original Code or Modifications or the combination of the Original Code and + Modifications, in each case including portions thereof. +
      1.4. "Electronic Distribution Mechanism" +
      means a mechanism generally accepted in the software development community for the + electronic transfer of data. +
      1.5. "Executable" +
      means Covered Code in any form other than Source Code. +
      1.6. "Initial Developer" +
      means the individual or entity identified as the Initial Developer in the Source Code + notice required by Exhibit A. +
      1.7. "Larger Work" +
      means a work which combines Covered Code or portions thereof with code not governed + by the terms of this License. +
      1.8. "License" +
      means this document. +
      1.8.1. "Licensable" +
      means having the right to grant, to the maximum extent possible, whether at the + time of the initial grant or subsequently acquired, any and all of the rights + conveyed herein. +
      1.9. "Modifications" +
      +

      means any addition to or deletion from the substance or structure of either the + Original Code or any previous Modifications. When Covered Code is released as a + series of files, a Modification is: +

        +
      1. Any addition to or deletion from the contents of a file + containing Original Code or previous Modifications. +
      2. Any new file that contains any part of the Original Code or + previous Modifications. +
      +
      1.10. "Original Code" +
      means Source Code of computer software code which is described in the Source Code + notice required by Exhibit A as Original Code, and which, + at the time of its release under this License is not already Covered Code governed + by this License. +
      1.10.1. "Patent Claims" +
      means any patent claim(s), now owned or hereafter acquired, including without + limitation, method, process, and apparatus claims, in any patent Licensable by + grantor. +
      1.11. "Source Code" +
      means the preferred form of the Covered Code for making modifications to it, + including all modules it contains, plus any associated interface definition files, + scripts used to control compilation and installation of an Executable, or source + code differential comparisons against either the Original Code or another well known, + available Covered Code of the Contributor's choice. The Source Code can be in a + compressed or archival form, provided the appropriate decompression or de-archiving + software is widely available for no charge. +
      1.12. "You" (or "Your") +
      means an individual or a legal entity exercising rights under, and complying with + all of the terms of, this License or a future version of this License issued under + Section 6.1. For legal entities, "You" includes any entity + which controls, is controlled by, or is under common control with You. For purposes of + this definition, "control" means (a) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or otherwise, or (b) + ownership of more than fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. +
      +

      2. Source Code License.

      +

      2.1. The Initial Developer Grant.

      +

      The Initial Developer hereby grants You a world-wide, royalty-free, non-exclusive + license, subject to third party intellectual property claims: +

        +
      1. under intellectual property rights (other than patent or + trademark) Licensable by Initial Developer to use, reproduce, modify, display, perform, + sublicense and distribute the Original Code (or portions thereof) with or without + Modifications, and/or as part of a Larger Work; and +
      2. under Patents Claims infringed by the making, using or selling + of Original Code, to make, have made, use, practice, sell, and offer for sale, and/or + otherwise dispose of the Original Code (or portions thereof). +
      3. the licenses granted in this Section 2.1 + (a) and (b) are effective on + the date Initial Developer first distributes Original Code under the terms of this + License. +
      4. Notwithstanding Section 2.1 (b) + above, no patent license is granted: 1) for code that You delete from the Original Code; + 2) separate from the Original Code; or 3) for infringements caused by: i) the + modification of the Original Code or ii) the combination of the Original Code with other + software or devices. +
      +

      2.2. Contributor Grant.

      +

      Subject to third party intellectual property claims, each Contributor hereby grants You + a world-wide, royalty-free, non-exclusive license +

        +
      1. under intellectual property rights (other than patent or trademark) + Licensable by Contributor, to use, reproduce, modify, display, perform, sublicense and + distribute the Modifications created by such Contributor (or portions thereof) either on + an unmodified basis, with other Modifications, as Covered Code and/or as part of a Larger + Work; and +
      2. under Patent Claims infringed by the making, using, or selling of + Modifications made by that Contributor either alone and/or in combination with its + Contributor Version (or portions of such combination), to make, use, sell, offer for + sale, have made, and/or otherwise dispose of: 1) Modifications made by that Contributor + (or portions thereof); and 2) the combination of Modifications made by that Contributor + with its Contributor Version (or portions of such combination). +
      3. the licenses granted in Sections 2.2 + (a) and 2.2 (b) are effective + on the date Contributor first makes Commercial Use of the Covered Code. +
      4. Notwithstanding Section 2.2 (b) + above, no patent license is granted: 1) for any code that Contributor has deleted from + the Contributor Version; 2) separate from the Contributor Version; 3) for infringements + caused by: i) third party modifications of Contributor Version or ii) the combination of + Modifications made by that Contributor with other software (except as part of the + Contributor Version) or other devices; or 4) under Patent Claims infringed by Covered Code + in the absence of Modifications made by that Contributor. +
      +

      3. Distribution Obligations.

      +

      3.1. Application of License.

      +

      The Modifications which You create or to which You contribute are governed by the terms + of this License, including without limitation Section 2.2. The + Source Code version of Covered Code may be distributed only under the terms of this License + or a future version of this License released under Section 6.1, + and You must include a copy of this License with every copy of the Source Code You + distribute. You may not offer or impose any terms on any Source Code version that alters or + restricts the applicable version of this License or the recipients' rights hereunder. + However, You may include an additional document offering the additional rights described in + Section 3.5. +

      3.2. Availability of Source Code.

      +

      Any Modification which You create or to which You contribute must be made available in + Source Code form under the terms of this License either on the same media as an Executable + version or via an accepted Electronic Distribution Mechanism to anyone to whom you made an + Executable version available; and if made available via Electronic Distribution Mechanism, + must remain available for at least twelve (12) months after the date it initially became + available, or at least six (6) months after a subsequent version of that particular + Modification has been made available to such recipients. You are responsible for ensuring + that the Source Code version remains available even if the Electronic Distribution + Mechanism is maintained by a third party. +

      3.3. Description of Modifications.

      +

      You must cause all Covered Code to which You contribute to contain a file documenting the + changes You made to create that Covered Code and the date of any change. You must include a + prominent statement that the Modification is derived, directly or indirectly, from Original + Code provided by the Initial Developer and including the name of the Initial Developer in + (a) the Source Code, and (b) in any notice in an Executable version or related documentation + in which You describe the origin or ownership of the Covered Code. +

      3.4. Intellectual Property Matters

      +

      (a) Third Party Claims

      +

      If Contributor has knowledge that a license under a third party's intellectual property + rights is required to exercise the rights granted by such Contributor under Sections + 2.1 or 2.2, Contributor must include a + text file with the Source Code distribution titled "LEGAL" which describes the claim and the + party making the claim in sufficient detail that a recipient will know whom to contact. If + Contributor obtains such knowledge after the Modification is made available as described in + Section 3.2, Contributor shall promptly modify the LEGAL file in + all copies Contributor makes available thereafter and shall take other steps (such as + notifying appropriate mailing lists or newsgroups) reasonably calculated to inform those who + received the Covered Code that new knowledge has been obtained. +

      (b) Contributor APIs

      +

      If Contributor's Modifications include an application programming interface and Contributor + has knowledge of patent licenses which are reasonably necessary to implement that + API, Contributor must also include this information in the + legal file. +

      (c) Representations.

      +

      Contributor represents that, except as disclosed pursuant to Section 3.4 + (a) above, Contributor believes that Contributor's Modifications + are Contributor's original creation(s) and/or Contributor has sufficient rights to grant the + rights conveyed by this License. +

      3.5. Required Notices.

      +

      You must duplicate the notice in Exhibit A in each file of the + Source Code. If it is not possible to put such notice in a particular Source Code file due to + its structure, then You must include such notice in a location (such as a relevant directory) + where a user would be likely to look for such a notice. If You created one or more + Modification(s) You may add your name as a Contributor to the notice described in + Exhibit A. You must also duplicate this License in any documentation + for the Source Code where You describe recipients' rights or ownership rights relating to + Covered Code. You may choose to offer, and to charge a fee for, warranty, support, indemnity + or liability obligations to one or more recipients of Covered Code. However, You may do so + only on Your own behalf, and not on behalf of the Initial Developer or any Contributor. You + must make it absolutely clear than any such warranty, support, indemnity or liability + obligation is offered by You alone, and You hereby agree to indemnify the Initial Developer + and every Contributor for any liability incurred by the Initial Developer or such Contributor + as a result of warranty, support, indemnity or liability terms You offer. +

      3.6. Distribution of Executable Versions.

      +

      You may distribute Covered Code in Executable form only if the requirements of Sections + 3.1, 3.2, + 3.3, 3.4 and + 3.5 have been met for that Covered Code, and if You include a + notice stating that the Source Code version of the Covered Code is available under the terms + of this License, including a description of how and where You have fulfilled the obligations + of Section 3.2. The notice must be conspicuously included in any + notice in an Executable version, related documentation or collateral in which You describe + recipients' rights relating to the Covered Code. You may distribute the Executable version of + Covered Code or ownership rights under a license of Your choice, which may contain terms + different from this License, provided that You are in compliance with the terms of this + License and that the license for the Executable version does not attempt to limit or alter the + recipient's rights in the Source Code version from the rights set forth in this License. If + You distribute the Executable version under a different license You must make it absolutely + clear that any terms which differ from this License are offered by You alone, not by the + Initial Developer or any Contributor. You hereby agree to indemnify the Initial Developer and + every Contributor for any liability incurred by the Initial Developer or such Contributor as + a result of any such terms You offer. +

      3.7. Larger Works.

      +

      You may create a Larger Work by combining Covered Code with other code not governed by the + terms of this License and distribute the Larger Work as a single product. In such a case, + You must make sure the requirements of this License are fulfilled for the Covered Code. +

      4. Inability to Comply Due to Statute or Regulation.

      +

      If it is impossible for You to comply with any of the terms of this License with respect to + some or all of the Covered Code due to statute, judicial order, or regulation then You must: + (a) comply with the terms of this License to the maximum extent possible; and (b) describe + the limitations and the code they affect. Such description must be included in the + legal file described in Section + 3.4 and must be included with all distributions of the Source Code. + Except to the extent prohibited by statute or regulation, such description must be + sufficiently detailed for a recipient of ordinary skill to be able to understand it. +

      5. Application of this License.

      +

      This License applies to code to which the Initial Developer has attached the notice in + Exhibit A and to related Covered Code. +

      6. Versions of the License.

      +

      6.1. New Versions

      +

      Netscape Communications Corporation ("Netscape") may publish revised and/or new versions + of the License from time to time. Each version will be given a distinguishing version number. +

      6.2. Effect of New Versions

      +

      Once Covered Code has been published under a particular version of the License, You may + always continue to use it under the terms of that version. You may also choose to use such + Covered Code under the terms of any subsequent version of the License published by Netscape. + No one other than Netscape has the right to modify the terms applicable to Covered Code + created under this License. +

      6.3. Derivative Works

      +

      If You create or use a modified version of this License (which you may only do in order to + apply it to code which is not already Covered Code governed by this License), You must (a) + rename Your license so that the phrases "Mozilla", "MOZILLAPL", "MOZPL", "Netscape", "MPL", + "NPL" or any confusingly similar phrase do not appear in your license (except to note that + your license differs from this License) and (b) otherwise make it clear that Your version of + the license contains terms which differ from the Mozilla Public License and Netscape Public + License. (Filling in the name of the Initial Developer, Original Code or Contributor in the + notice described in Exhibit A shall not of themselves be deemed to + be modifications of this License.) +

      7. Disclaimer of warranty

      +

      Covered code is provided under this license on an "as is" + basis, without warranty of any kind, either expressed or implied, including, without + limitation, warranties that the covered code is free of defects, merchantable, fit for a + particular purpose or non-infringing. The entire risk as to the quality and performance of + the covered code is with you. Should any covered code prove defective in any respect, you + (not the initial developer or any other contributor) assume the cost of any necessary + servicing, repair or correction. This disclaimer of warranty constitutes an essential part + of this license. No use of any covered code is authorized hereunder except under this + disclaimer. +

      8. Termination

      +

      8.1. This License and the rights granted hereunder will terminate + automatically if You fail to comply with terms herein and fail to cure such breach + within 30 days of becoming aware of the breach. All sublicenses to the Covered Code which + are properly granted shall survive any termination of this License. Provisions which, by + their nature, must remain in effect beyond the termination of this License shall survive. +

      8.2. If You initiate litigation by asserting a patent infringement + claim (excluding declatory judgment actions) against Initial Developer or a Contributor + (the Initial Developer or Contributor against whom You file such action is referred to + as "Participant") alleging that: +

        +
      1. such Participant's Contributor Version directly or indirectly + infringes any patent, then any and all rights granted by such Participant to You under + Sections 2.1 and/or 2.2 of this + License shall, upon 60 days notice from Participant terminate prospectively, unless if + within 60 days after receipt of notice You either: (i) agree in writing to pay + Participant a mutually agreeable reasonable royalty for Your past and future use of + Modifications made by such Participant, or (ii) withdraw Your litigation claim with + respect to the Contributor Version against such Participant. If within 60 days of + notice, a reasonable royalty and payment arrangement are not mutually agreed upon in + writing by the parties or the litigation claim is not withdrawn, the rights granted by + Participant to You under Sections 2.1 and/or + 2.2 automatically terminate at the expiration of the 60 day + notice period specified above. +
      2. any software, hardware, or device, other than such Participant's + Contributor Version, directly or indirectly infringes any patent, then any rights + granted to You by such Participant under Sections 2.1(b) + and 2.2(b) are revoked effective as of the date You first + made, used, sold, distributed, or had made, Modifications made by that Participant. +
      +

      8.3. If You assert a patent infringement claim against Participant + alleging that such Participant's Contributor Version directly or indirectly infringes + any patent where such claim is resolved (such as by license or settlement) prior to the + initiation of patent infringement litigation, then the reasonable value of the licenses + granted by such Participant under Sections 2.1 or + 2.2 shall be taken into account in determining the amount or + value of any payment or license. +

      8.4. In the event of termination under Sections + 8.1 or 8.2 above, all end user + license agreements (excluding distributors and resellers) which have been validly + granted by You or any distributor hereunder prior to termination shall survive + termination. +

      9. Limitation of liability

      +

      Under no circumstances and under no legal theory, whether + tort (including negligence), contract, or otherwise, shall you, the initial developer, + any other contributor, or any distributor of covered code, or any supplier of any of + such parties, be liable to any person for any indirect, special, incidental, or + consequential damages of any character including, without limitation, damages for loss + of goodwill, work stoppage, computer failure or malfunction, or any and all other + commercial damages or losses, even if such party shall have been informed of the + possibility of such damages. This limitation of liability shall not apply to liability + for death or personal injury resulting from such party's negligence to the extent + applicable law prohibits such limitation. Some jurisdictions do not allow the exclusion + or limitation of incidental or consequential damages, so this exclusion and limitation + may not apply to you. +

      10. U.S. government end users

      +

      The Covered Code is a "commercial item," as that term is defined in 48 + C.F.R. 2.101 (Oct. 1995), consisting of + "commercial computer software" and "commercial computer software documentation," as such + terms are used in 48 C.F.R. 12.212 (Sept. + 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R. + 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users + acquire Covered Code with only those rights set forth herein. +

      11. Miscellaneous

      +

      This License represents the complete agreement concerning subject matter hereof. If + any provision of this License is held to be unenforceable, such provision shall be + reformed only to the extent necessary to make it enforceable. This License shall be + governed by California law provisions (except to the extent applicable law, if any, + provides otherwise), excluding its conflict-of-law provisions. With respect to + disputes in which at least one party is a citizen of, or an entity chartered or + registered to do business in the United States of America, any litigation relating to + this License shall be subject to the jurisdiction of the Federal Courts of the + Northern District of California, with venue lying in Santa Clara County, California, + with the losing party responsible for costs, including without limitation, court + costs and reasonable attorneys' fees and expenses. The application of the United + Nations Convention on Contracts for the International Sale of Goods is expressly + excluded. Any law or regulation which provides that the language of a contract + shall be construed against the drafter shall not apply to this License. +

      12. Responsibility for claims

      +

      As between Initial Developer and the Contributors, each party is responsible for + claims and damages arising, directly or indirectly, out of its utilization of rights + under this License and You agree to work with Initial Developer and Contributors to + distribute such responsibility on an equitable basis. Nothing herein is intended or + shall be deemed to constitute any admission of liability. +

      13. Multiple-licensed code

      +

      Initial Developer may designate portions of the Covered Code as + "Multiple-Licensed". "Multiple-Licensed" means that the Initial Developer permits + you to utilize portions of the Covered Code under Your choice of the MPL + or the alternative licenses, if any, specified by the Initial Developer in the file + described in Exhibit A. +

      Exhibit A - Mozilla Public License.

      +
      "The contents of this file are subject to the Mozilla Public License
      +Version 1.1 (the "License"); you may not use this file except in
      +compliance with the License. You may obtain a copy of the License at
      +http://www.mozilla.org/MPL/
      +
      +Software distributed under the License is distributed on an "AS IS"
      +basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the
      +License for the specific language governing rights and limitations
      +under the License.
      +
      +The Original Code is JTransforms.
      +
      +The Initial Developer of the Original Code is
      +Piotr Wendykier, Emory University.
      +Portions created by the Initial Developer are Copyright (C) 2007-2009
      +the Initial Developer. All Rights Reserved.
      +
      +Alternatively, the contents of this file may be used under the terms of
      +either the GNU General Public License Version 2 or later (the "GPL"), or
      +the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
      +in which case the provisions of the GPL or the LGPL are applicable instead
      +of those above. If you wish to allow use of your version of this file only
      +under the terms of either the GPL or the LGPL, and not to allow others to
      +use your version of this file under the terms of the MPL, indicate your
      +decision by deleting the provisions above and replace them with the notice
      +and other provisions required by the GPL or the LGPL. If you do not delete
      +the provisions above, a recipient may use your version of this file under
      +the terms of any one of the MPL, the GPL or the LGPL.
      +

      NOTE: The text of this Exhibit A may differ slightly from the text of + the notices in the Source Code files of the Original Code. You should + use the text of this Exhibit A rather than the text found in the + Original Code Source Code for Your Modifications. + +

      \ No newline at end of file diff --git a/licenses/LICENSE-kryo.txt b/licenses-binary/LICENSE-kryo.txt similarity index 100% rename from licenses/LICENSE-kryo.txt rename to licenses-binary/LICENSE-kryo.txt diff --git a/licenses-binary/LICENSE-leveldbjni.txt b/licenses-binary/LICENSE-leveldbjni.txt new file mode 100644 index 0000000000000..b4dabb9174c6d --- /dev/null +++ b/licenses-binary/LICENSE-leveldbjni.txt @@ -0,0 +1,27 @@ +Copyright (c) 2011 FuseSource Corp. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of FuseSource Corp. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-machinist.txt b/licenses-binary/LICENSE-machinist.txt new file mode 100644 index 0000000000000..68cc3a3e3a9c4 --- /dev/null +++ b/licenses-binary/LICENSE-machinist.txt @@ -0,0 +1,19 @@ +Copyright (c) 2011-2014 Erik Osheim, Tom Switzer + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-matchMedia-polyfill.txt b/licenses-binary/LICENSE-matchMedia-polyfill.txt new file mode 100644 index 0000000000000..2fd0bc2b37448 --- /dev/null +++ b/licenses-binary/LICENSE-matchMedia-polyfill.txt @@ -0,0 +1 @@ +matchMedia() polyfill - Test a CSS media type/query in JS. Authors & copyright (c) 2012: Scott Jehl, Paul Irish, Nicholas Zakas. Dual MIT/BSD license \ No newline at end of file diff --git a/licenses/LICENSE-minlog.txt b/licenses-binary/LICENSE-minlog.txt similarity index 100% rename from licenses/LICENSE-minlog.txt rename to licenses-binary/LICENSE-minlog.txt diff --git a/licenses-binary/LICENSE-modernizr.txt b/licenses-binary/LICENSE-modernizr.txt new file mode 100644 index 0000000000000..2bf24b9b9f848 --- /dev/null +++ b/licenses-binary/LICENSE-modernizr.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-netlib.txt b/licenses-binary/LICENSE-netlib.txt similarity index 100% rename from licenses/LICENSE-netlib.txt rename to licenses-binary/LICENSE-netlib.txt diff --git a/licenses/LICENSE-paranamer.txt b/licenses-binary/LICENSE-paranamer.txt similarity index 100% rename from licenses/LICENSE-paranamer.txt rename to licenses-binary/LICENSE-paranamer.txt diff --git a/licenses/LICENSE-jpmml-model.txt b/licenses-binary/LICENSE-pmml-model.txt similarity index 100% rename from licenses/LICENSE-jpmml-model.txt rename to licenses-binary/LICENSE-pmml-model.txt diff --git a/licenses/LICENSE-protobuf.txt b/licenses-binary/LICENSE-protobuf.txt similarity index 100% rename from licenses/LICENSE-protobuf.txt rename to licenses-binary/LICENSE-protobuf.txt diff --git a/licenses-binary/LICENSE-py4j.txt b/licenses-binary/LICENSE-py4j.txt new file mode 100644 index 0000000000000..70af3e69ed67a --- /dev/null +++ b/licenses-binary/LICENSE-py4j.txt @@ -0,0 +1,27 @@ +Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +- The name of the author may not be used to endorse or promote products +derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + diff --git a/licenses/LICENSE-pyrolite.txt b/licenses-binary/LICENSE-pyrolite.txt similarity index 100% rename from licenses/LICENSE-pyrolite.txt rename to licenses-binary/LICENSE-pyrolite.txt diff --git a/licenses/LICENSE-reflectasm.txt b/licenses-binary/LICENSE-reflectasm.txt similarity index 100% rename from licenses/LICENSE-reflectasm.txt rename to licenses-binary/LICENSE-reflectasm.txt diff --git a/licenses-binary/LICENSE-respond.txt b/licenses-binary/LICENSE-respond.txt new file mode 100644 index 0000000000000..dea4ff9e5b2ea --- /dev/null +++ b/licenses-binary/LICENSE-respond.txt @@ -0,0 +1,22 @@ +Copyright (c) 2012 Scott Jehl + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-sbt-launch-lib.txt b/licenses-binary/LICENSE-sbt-launch-lib.txt new file mode 100644 index 0000000000000..3b9156baaab78 --- /dev/null +++ b/licenses-binary/LICENSE-sbt-launch-lib.txt @@ -0,0 +1,26 @@ +// Generated from http://www.opensource.org/licenses/bsd-license.php +Copyright (c) 2011, Paul Phillips. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of the author nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-scala.txt b/licenses-binary/LICENSE-scala.txt similarity index 100% rename from licenses/LICENSE-scala.txt rename to licenses-binary/LICENSE-scala.txt diff --git a/licenses-binary/LICENSE-scopt.txt b/licenses-binary/LICENSE-scopt.txt new file mode 100644 index 0000000000000..e92e9b592fba0 --- /dev/null +++ b/licenses-binary/LICENSE-scopt.txt @@ -0,0 +1,9 @@ +This project is licensed under the MIT license. + +Copyright (c) scopt contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-slf4j.txt b/licenses-binary/LICENSE-slf4j.txt similarity index 100% rename from licenses/LICENSE-slf4j.txt rename to licenses-binary/LICENSE-slf4j.txt diff --git a/licenses-binary/LICENSE-sorttable.js.txt b/licenses-binary/LICENSE-sorttable.js.txt new file mode 100644 index 0000000000000..b31a5b206bf40 --- /dev/null +++ b/licenses-binary/LICENSE-sorttable.js.txt @@ -0,0 +1,16 @@ +Copyright (c) 1997-2007 Stuart Langridge + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/licenses/LICENSE-spire.txt b/licenses-binary/LICENSE-spire.txt similarity index 100% rename from licenses/LICENSE-spire.txt rename to licenses-binary/LICENSE-spire.txt diff --git a/licenses-binary/LICENSE-vis.txt b/licenses-binary/LICENSE-vis.txt new file mode 100644 index 0000000000000..18b7323059a41 --- /dev/null +++ b/licenses-binary/LICENSE-vis.txt @@ -0,0 +1,22 @@ +vis.js +https://github.com/almende/vis + +A dynamic, browser-based visualization library. + +@version 4.16.1 +@date 2016-04-18 + +@license +Copyright (C) 2011-2016 Almende B.V, http://almende.com + +Vis.js is dual licensed under both + +* The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0 + +and + +* The MIT License + http://opensource.org/licenses/MIT + +Vis.js may be distributed under either license. \ No newline at end of file diff --git a/licenses/LICENSE-xmlenc.txt b/licenses-binary/LICENSE-xmlenc.txt similarity index 100% rename from licenses/LICENSE-xmlenc.txt rename to licenses-binary/LICENSE-xmlenc.txt diff --git a/licenses/LICENSE-zstd-jni.txt b/licenses-binary/LICENSE-zstd-jni.txt similarity index 100% rename from licenses/LICENSE-zstd-jni.txt rename to licenses-binary/LICENSE-zstd-jni.txt diff --git a/licenses/LICENSE-zstd.txt b/licenses-binary/LICENSE-zstd.txt similarity index 100% rename from licenses/LICENSE-zstd.txt rename to licenses-binary/LICENSE-zstd.txt diff --git a/licenses/LICENSE-CC0.txt b/licenses/LICENSE-CC0.txt new file mode 100644 index 0000000000000..1625c17936079 --- /dev/null +++ b/licenses/LICENSE-CC0.txt @@ -0,0 +1,121 @@ +Creative Commons Legal Code + +CC0 1.0 Universal + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE + LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN + ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS + INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES + REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS + PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM + THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED + HEREUNDER. + +Statement of Purpose + +The laws of most jurisdictions throughout the world automatically confer +exclusive Copyright and Related Rights (defined below) upon the creator +and subsequent owner(s) (each and all, an "owner") of an original work of +authorship and/or a database (each, a "Work"). + +Certain owners wish to permanently relinquish those rights to a Work for +the purpose of contributing to a commons of creative, cultural and +scientific works ("Commons") that the public can reliably and without fear +of later claims of infringement build upon, modify, incorporate in other +works, reuse and redistribute as freely as possible in any form whatsoever +and for any purposes, including without limitation commercial purposes. +These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. In addition, to the +extent the Waiver is so judged Affirmer hereby grants to each affected +person a royalty-free, non transferable, non sublicensable, non exclusive, +irrevocable and unconditional license to exercise Affirmer's Copyright and +Related Rights in the Work (i) in all territories worldwide, (ii) for the +maximum duration provided by applicable law or treaty (including future +time extensions), (iii) in any current or future medium and for any number +of copies, and (iv) for any purpose whatsoever, including without +limitation commercial, advertising or promotional purposes (the +"License"). The License shall be deemed effective as of the date CC0 was +applied by Affirmer to the Work. Should any part of the License for any +reason be judged legally invalid or ineffective under applicable law, such +partial invalidity or ineffectiveness shall not invalidate the remainder +of the License, and in such case Affirmer hereby affirms that he or she +will not (i) exercise any of his or her remaining Copyright and Related +Rights in the Work or (ii) assert any associated claims and causes of +action with respect to the Work, in either case contrary to Affirmer's +express Statement of Purpose. + +4. Limitations and Disclaimers. + + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, + statutory or otherwise, including without limitation warranties of + title, merchantability, fitness for a particular purpose, non + infringement, or the absence of latent or other defects, accuracy, or + the present or absence of errors, whether or not discoverable, all to + the greatest extent permissible under applicable law. + c. Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without + limitation any person's Copyright and Related Rights in the Work. + Further, Affirmer disclaims responsibility for obtaining any necessary + consents, permissions or other rights required for any use of the + Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to + this CC0 or use of the Work. \ No newline at end of file diff --git a/licenses/LICENSE-SnapTree.txt b/licenses/LICENSE-SnapTree.txt deleted file mode 100644 index a538825d89ec5..0000000000000 --- a/licenses/LICENSE-SnapTree.txt +++ /dev/null @@ -1,35 +0,0 @@ -SNAPTREE LICENSE - -Copyright (c) 2009-2012 Stanford University, unless otherwise specified. -All rights reserved. - -This software was developed by the Pervasive Parallelism Laboratory of -Stanford University, California, USA. - -Permission to use, copy, modify, and distribute this software in source -or binary form for any purpose with or without fee is hereby granted, -provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of Stanford University nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. diff --git a/licenses/LICENSE-bootstrap.txt b/licenses/LICENSE-bootstrap.txt new file mode 100644 index 0000000000000..6c711832fbc85 --- /dev/null +++ b/licenses/LICENSE-bootstrap.txt @@ -0,0 +1,13 @@ +Copyright 2013 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/licenses/LICENSE-boto.txt b/licenses/LICENSE-boto.txt deleted file mode 100644 index 7bba0cd9e10a4..0000000000000 --- a/licenses/LICENSE-boto.txt +++ /dev/null @@ -1,20 +0,0 @@ -Copyright (c) 2006-2008 Mitch Garnaat http://garnaat.org/ - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, dis- -tribute, sublicense, and/or sell copies of the Software, and to permit -persons to whom the Software is furnished to do so, subject to the fol- -lowing conditions: - -The above copyright notice and this permission notice shall be included -in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL- -ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-datatables.txt b/licenses/LICENSE-datatables.txt new file mode 100644 index 0000000000000..bb7708b5b5a49 --- /dev/null +++ b/licenses/LICENSE-datatables.txt @@ -0,0 +1,7 @@ +Copyright (C) 2008-2018, SpryMedia Ltd. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-graphlib-dot.txt b/licenses/LICENSE-graphlib-dot.txt index c9e18cd562423..4864fe05e9803 100644 --- a/licenses/LICENSE-graphlib-dot.txt +++ b/licenses/LICENSE-graphlib-dot.txt @@ -1,4 +1,4 @@ -Copyright (c) 2012-2013 Chris Pettitt +Copyright (c) 2013 Chris Pettitt Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/licenses/LICENSE-jbcrypt.txt b/licenses/LICENSE-jbcrypt.txt deleted file mode 100644 index d332534c06356..0000000000000 --- a/licenses/LICENSE-jbcrypt.txt +++ /dev/null @@ -1,17 +0,0 @@ -jBCrypt is subject to the following license: - -/* - * Copyright (c) 2006 Damien Miller - * - * Permission to use, copy, modify, and distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ diff --git a/licenses/LICENSE-join.txt b/licenses/LICENSE-join.txt new file mode 100644 index 0000000000000..1d916090e4ea0 --- /dev/null +++ b/licenses/LICENSE-join.txt @@ -0,0 +1,30 @@ +Copyright (c) 2011, Douban Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + + * Neither the name of the Douban Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-jquery.txt b/licenses/LICENSE-jquery.txt index e1dd696d3b6cc..45930542204fb 100644 --- a/licenses/LICENSE-jquery.txt +++ b/licenses/LICENSE-jquery.txt @@ -1,9 +1,20 @@ -The MIT License (MIT) +Copyright JS Foundation and other contributors, https://js.foundation/ -Copyright (c) +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-json-formatter.txt b/licenses/LICENSE-json-formatter.txt new file mode 100644 index 0000000000000..5193348fce126 --- /dev/null +++ b/licenses/LICENSE-json-formatter.txt @@ -0,0 +1,6 @@ +Copyright 2014 Mohsen Azimi + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/licenses/LICENSE-matchMedia-polyfill.txt b/licenses/LICENSE-matchMedia-polyfill.txt new file mode 100644 index 0000000000000..2fd0bc2b37448 --- /dev/null +++ b/licenses/LICENSE-matchMedia-polyfill.txt @@ -0,0 +1 @@ +matchMedia() polyfill - Test a CSS media type/query in JS. Authors & copyright (c) 2012: Scott Jehl, Paul Irish, Nicholas Zakas. Dual MIT/BSD license \ No newline at end of file diff --git a/licenses/LICENSE-postgresql.txt b/licenses/LICENSE-postgresql.txt deleted file mode 100644 index 515bf9af4d432..0000000000000 --- a/licenses/LICENSE-postgresql.txt +++ /dev/null @@ -1,24 +0,0 @@ -PostgreSQL Database Management System -(formerly known as Postgres, then as Postgres95) - -Portions Copyright (c) 1996-2010, PostgreSQL Global Development Group - -Portions Copyright (c) 1994, The Regents of the University of California - -Permission to use, copy, modify, and distribute this software and its -documentation for any purpose, without fee, and without a written agreement -is hereby granted, provided that the above copyright notice and this -paragraph and the following two paragraphs appear in all copies. - -IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR -DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING -LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS -DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - -THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, -INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY -AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS -ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO -PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. - diff --git a/licenses/LICENSE-respond.txt b/licenses/LICENSE-respond.txt new file mode 100644 index 0000000000000..dea4ff9e5b2ea --- /dev/null +++ b/licenses/LICENSE-respond.txt @@ -0,0 +1,22 @@ +Copyright (c) 2012 Scott Jehl + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-scalacheck.txt b/licenses/LICENSE-scalacheck.txt deleted file mode 100644 index cb8f97842f4c4..0000000000000 --- a/licenses/LICENSE-scalacheck.txt +++ /dev/null @@ -1,32 +0,0 @@ -ScalaCheck LICENSE - -Copyright (c) 2007-2015, Rickard Nilsson -All rights reserved. - -Permission to use, copy, modify, and distribute this software in source -or binary form for any purpose with or without fee is hereby granted, -provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of the author nor the names of its contributors - may be used to endorse or promote products derived from this - software without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-vis.txt b/licenses/LICENSE-vis.txt new file mode 100644 index 0000000000000..18b7323059a41 --- /dev/null +++ b/licenses/LICENSE-vis.txt @@ -0,0 +1,22 @@ +vis.js +https://github.com/almende/vis + +A dynamic, browser-based visualization library. + +@version 4.16.1 +@date 2016-04-18 + +@license +Copyright (C) 2011-2016 Almende B.V, http://almende.com + +Vis.js is dual licensed under both + +* The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0 + +and + +* The MIT License + http://opensource.org/licenses/MIT + +Vis.js may be distributed under either license. \ No newline at end of file diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister index 5e5484fd8784d..f14431d50feec 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -1,2 +1,4 @@ org.apache.spark.ml.regression.InternalLinearRegressionModelWriter -org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter \ No newline at end of file +org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter +org.apache.spark.ml.clustering.InternalKMeansModelWriter +org.apache.spark.ml.clustering.PMMLKMeansModelWriter \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 771cd4fe91dcf..c9786f1f7ceb1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -97,9 +97,11 @@ class DecisionTreeClassifier @Since("1.4.0") ( override def setSeed(value: Long): this.type = set(seed, value) override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { + val instr = Instrumentation.create(this, dataset) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) + instr.logNumClasses(numClasses) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -110,8 +112,8 @@ class DecisionTreeClassifier @Since("1.4.0") ( val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = getOldStrategy(categoricalFeatures, numClasses) - val instr = Instrumentation.create(this, oldDataset) - instr.logParams(params: _*) + instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, + cacheNodeIds, checkpointInterval, impurity, seed) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid)) @@ -125,7 +127,8 @@ class DecisionTreeClassifier @Since("1.4.0") ( private[ml] def train(data: RDD[LabeledPoint], oldStrategy: OldStrategy): DecisionTreeClassificationModel = { val instr = Instrumentation.create(this, data) - instr.logParams(params: _*) + instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, + cacheNodeIds, checkpointInterval, impurity, seed) val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid)) @@ -279,7 +282,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true) val model = new DecisionTreeClassificationModel(metadata.uid, root.asInstanceOf[ClassificationNode], numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index c0255103bc313..337133a2e2326 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -146,12 +146,21 @@ class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) + /** @group setParam */ + @Since("2.4.0") + def setValidationIndicatorCol(value: String): this.type = { + set(validationIndicatorCol, value) + } + override protected def train(dataset: Dataset[_]): GBTClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + + val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty + // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports // 2 classes now. This lets us provide a more precise error message. - val oldDataset: RDD[LabeledPoint] = + val convert2LabeledPoint = (dataset: Dataset[_]) => { dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label == 0 || label == 1, s"GBTClassifier was given" + @@ -159,7 +168,18 @@ class GBTClassifier @Since("1.4.0") ( s" GBTClassifier currently only supports binary classification.") LabeledPoint(label, features) } - val numFeatures = oldDataset.first().features.size + } + + val (trainDataset, validationDataset) = if (withValidation) { + ( + convert2LabeledPoint(dataset.filter(not(col($(validationIndicatorCol))))), + convert2LabeledPoint(dataset.filter(col($(validationIndicatorCol)))) + ) + } else { + (convert2LabeledPoint(dataset), null) + } + + val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val numClasses = 2 @@ -169,15 +189,21 @@ class GBTClassifier @Since("1.4.0") ( s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } - val instr = Instrumentation.create(this, oldDataset) + val instr = Instrumentation.create(this, dataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, - seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) + seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, + validationIndicatorCol) instr.logNumFeatures(numFeatures) instr.logNumClasses(numClasses) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed), $(featureSubsetStrategy)) + val (baseLearners, learnerWeights) = if (withValidation) { + GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } else { + GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) + } + val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m @@ -334,6 +360,21 @@ class GBTClassificationModel private[ml]( // hard coded loss, which is not meant to be changed in the model private val loss = getOldLossType + /** + * Method to compute error or loss for every iteration of gradient boosting. + * + * @param dataset Dataset for validation. + */ + @Since("2.4.0") + def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = { + val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => LabeledPoint(label, features) + } + GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss, + OldAlgo.Classification + ) + } + @Since("2.0.0") override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this) } @@ -379,14 +420,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 8f950cd28c3aa..38eb04556b775 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -170,7 +170,7 @@ class LinearSVC @Since("2.2.0") ( Instance(label, weight, features) } - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, threshold, aggregationDepth) @@ -187,6 +187,9 @@ class LinearSVC @Since("2.2.0") ( (new MultivariateOnlineSummarizer, new MultiClassSummarizer) )(seqOp, combOp, $(aggregationDepth)) } + instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count) + instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString) + instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString) val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -209,7 +212,7 @@ class LinearSVC @Since("2.2.0") ( if (numInvalid != 0) { val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " + s"Found $numInvalid invalid labels." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } @@ -246,7 +249,7 @@ class LinearSVC @Since("2.2.0") ( bcFeaturesStd.destroy(blocking = false) if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } @@ -377,7 +380,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] { val Row(coefficients: Vector, intercept: Double) = data.select("coefficients", "intercept").head() val model = new LinearSVCModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index ee4b01058c75c..92e342ed4a464 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -500,7 +500,7 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept) @@ -816,7 +816,7 @@ class LogisticRegression @Since("1.2.0") ( if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } @@ -1202,6 +1202,11 @@ class LogisticRegressionModel private[spark] ( */ @Since("1.6.0") override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) + + override def toString: String = { + s"LogisticRegressionModel: " + + s"uid = ${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures" + } } @@ -1270,7 +1275,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { numClasses, isMultinomial) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index af2e4699924e5..57ba47e596a97 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -388,7 +388,7 @@ object MultilayerPerceptronClassificationModel val weights = data.getAs[Vector](1) val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 0293e03d47435..1dde18d2d1a31 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -126,8 +126,10 @@ class NaiveBayes @Since("1.5.0") ( private[spark] def trainWithLabelCheck( dataset: Dataset[_], positiveLabel: Boolean): NaiveBayesModel = { + val instr = Instrumentation.create(this, dataset) if (positiveLabel && isDefined(thresholds)) { val numClasses = getNumClasses(dataset) + instr.logNumClasses(numClasses) require($(thresholds).length == numClasses, this.getClass.getSimpleName + ".train() called with non-matching numClasses and thresholds.length." + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") @@ -146,7 +148,6 @@ class NaiveBayes @Since("1.5.0") ( } } - val instr = Instrumentation.create(this, dataset) instr.logParams(labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol, probabilityCol, modelType, smoothing, thresholds) @@ -407,7 +408,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { .head() val model = new NaiveBayesModel(metadata.uid, pi, theta) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 5348d882cfd67..3474b61e40136 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -289,7 +289,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] { DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc) } val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models) - DefaultParamsReader.getAndSetParams(ovrModel, metadata) + metadata.getAndSetParams(ovrModel) ovrModel.set("classifier", classifier) ovrModel } @@ -366,7 +366,7 @@ final class OneVsRest @Since("1.4.0") ( transformSchema(dataset.schema) val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, predictionCol, parallelism) + instr.logParams(labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol) instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName) // determine number of classes either from metadata if provided, or via computation. @@ -383,7 +383,7 @@ final class OneVsRest @Since("1.4.0") ( getClassifier match { case _: HasWeightCol => true case c => - logWarning(s"weightCol is ignored, as it is not supported by $c now.") + instr.logWarning(s"weightCol is ignored, as it is not supported by $c now.") false } } @@ -484,7 +484,7 @@ object OneVsRest extends MLReadable[OneVsRest] { override def load(path: String): OneVsRest = { val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className) val ovr = new OneVsRest(metadata.uid) - DefaultParamsReader.getAndSetParams(ovr, metadata) + metadata.getAndSetParams(ovr) ovr.setClassifier(classifier) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index bb972e9706fc1..040db3b94b041 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -116,6 +116,7 @@ class RandomForestClassifier @Since("1.4.0") ( set(featureSubsetStrategy, value) override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { + val instr = Instrumentation.create(this, dataset) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) @@ -130,7 +131,6 @@ class RandomForestClassifier @Since("1.4.0") ( val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) - val instr = Instrumentation.create(this, oldDataset) instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval) @@ -141,6 +141,8 @@ class RandomForestClassifier @Since("1.4.0") ( val numFeatures = oldDataset.first().features.size val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) + instr.logNumClasses(numClasses) + instr.logNumFeatures(numFeatures) instr.logSuccess(m) m } @@ -319,14 +321,14 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica case (treeMetadata, root) => val tree = new DecisionTreeClassificationModel(treeMetadata.uid, root.asInstanceOf[ClassificationNode], numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index f7c422dc0faea..de564471c2b4d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -22,17 +22,15 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} -import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} @@ -75,7 +73,7 @@ private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -113,7 +111,8 @@ class BisectingKMeansModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val predictUDF = udf((vector: Vector) => predict(vector)) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + dataset.withColumn($(predictionCol), + predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("2.0.0") @@ -132,9 +131,9 @@ class BisectingKMeansModel private[ml] ( */ @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } - parentModel.computeCost(data.map(OldVectors.fromML)) + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) + val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) + parentModel.computeCost(data) } @Since("2.0.0") @@ -193,7 +192,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { val dataPath = new Path(path, "data").toString val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath) val model = new BisectingKMeansModel(metadata.uid, mllibModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -260,12 +259,11 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): BisectingKMeansModel = { transformSchema(dataset.schema, logging = true) - val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) - val instr = Instrumentation.create(this, rdd) - instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize) + val instr = Instrumentation.create(this, dataset) + instr.logParams(featuresCol, predictionCol, k, maxIter, seed, + minDivisibleClusterSize, distanceMeasure) val bkm = new MLlibBisectingKMeans() .setK($(k)) @@ -276,8 +274,9 @@ class BisectingKMeans @Since("2.0.0") ( val parentModel = bkm.run(rdd) val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( - model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.transform(dataset), $(predictionCol), $(featuresCol), $(k), $(maxIter)) model.setSummary(Some(summary)) + instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logSuccess(model) model } @@ -305,6 +304,7 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] { * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. */ @Since("2.1.0") @Experimental @@ -312,4 +312,5 @@ class BisectingKMeansSummary private[clustering] ( predictions: DataFrame, predictionCol: String, featuresCol: String, - k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) + k: Int, + numIter: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala index 44e832b058b62..7da4c43a1abf3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.clustering -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.sql.{DataFrame, Row} /** @@ -28,13 +28,15 @@ import org.apache.spark.sql.{DataFrame, Row} * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. */ @Experimental class ClusteringSummary private[clustering] ( @transient val predictions: DataFrame, val predictionCol: String, val featuresCol: String, - val k: Int) extends Serializable { + val k: Int, + @Since("2.4.0") val numIter: Int) extends Serializable { /** * Cluster centers of the transformed data. diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index f19ad7a5a6938..dae64ba9a515d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatr Vector => OldVector, Vectors => OldVectors} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} @@ -63,7 +63,7 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT) } @@ -109,8 +109,9 @@ class GaussianMixtureModel private[ml] ( transformSchema(dataset.schema, logging = true) val predUDF = udf((vector: Vector) => predict(vector)) val probUDF = udf((vector: Vector) => predictProbability(vector)) - dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) - .withColumn($(probabilityCol), probUDF(col($(featuresCol)))) + dataset + .withColumn($(predictionCol), predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) + .withColumn($(probabilityCol), probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("2.0.0") @@ -233,7 +234,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { } val model = new GaussianMixtureModel(metadata.uid, weights, gaussians) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -340,7 +341,8 @@ class GaussianMixture @Since("2.0.0") ( val sc = dataset.sparkSession.sparkContext val numClusters = $(k) - val instances: RDD[Vector] = dataset.select(col($(featuresCol))).rdd.map { + val instances: RDD[Vector] = dataset + .select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map { case Row(features: Vector) => features }.cache() @@ -350,7 +352,7 @@ class GaussianMixture @Since("2.0.0") ( s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" + s" matrix is quadratic in the number of features.") - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol) instr.logNumFeatures(numFeatures) @@ -421,8 +423,10 @@ class GaussianMixture @Since("2.0.0") ( val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)).setParent(this) val summary = new GaussianMixtureSummary(model.transform(dataset), - $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood) + $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iter) model.setSummary(Some(summary)) + instr.logNamedValue("logLikelihood", logLikelihood) + instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logSuccess(model) model } @@ -683,6 +687,7 @@ private class ExpectationAggregator( * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. * @param logLikelihood Total log-likelihood for this model on the given data. + * @param numIter Number of iterations. */ @Since("2.0.0") @Experimental @@ -692,8 +697,9 @@ class GaussianMixtureSummary private[clustering] ( @Since("2.0.0") val probabilityCol: String, featuresCol: String, k: Int, - @Since("2.2.0") val logLikelihood: Double) - extends ClusteringSummary(predictions, predictionCol, featuresCol, k) { + @Since("2.2.0") val logLikelihood: Double, + numIter: Int) + extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) { /** * Probability of each cluster. diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 987a4285ebad4..f40037a8d9aa9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,12 +17,14 @@ package org.apache.spark.ml.clustering +import scala.collection.mutable + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.{Estimator, Model, PipelineStage} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ @@ -30,8 +32,8 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils.majorVersion @@ -90,7 +92,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -103,8 +105,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, - private val parentModel: MLlibKMeansModel) - extends Model[KMeansModel] with KMeansParams with MLWritable { + private[clustering] val parentModel: MLlibKMeansModel) + extends Model[KMeansModel] with KMeansParams with GeneralMLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -123,8 +125,11 @@ class KMeansModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) + val predictUDF = udf((vector: Vector) => predict(vector)) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + + dataset.withColumn($(predictionCol), + predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("1.5.0") @@ -144,22 +149,20 @@ class KMeansModel private[ml] ( // TODO: Replace the temp fix when we have proper evaluators defined for clustering. @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) + val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) parentModel.computeCost(data) } /** - * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. * * For [[KMeansModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. * */ @Since("1.6.0") - override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) + override def write: GeneralMLWriter = new GeneralMLWriter(this) private var trainingSummary: Option[KMeansSummary] = None @@ -185,6 +188,47 @@ class KMeansModel private[ml] ( } } +/** Helper class for storing model data */ +private case class ClusterData(clusterIdx: Int, clusterCenter: Vector) + + +/** A writer for KMeans that handles the "internal" (or default) format */ +private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "internal" + override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel" + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[KMeansModel] + val sc = sparkSession.sparkContext + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers + val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map { + case (center, idx) => + ClusterData(idx, center) + } + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) + } +} + +/** A writer for KMeans that handles the "pmml" format */ +private class PMMLKMeansModelWriter extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "pmml" + override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel" + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[KMeansModel] + val sc = sparkSession.sparkContext + instance.parentModel.toPMML(sc, path) + } +} + + @Since("1.6.0") object KMeansModel extends MLReadable[KMeansModel] { @@ -194,30 +238,12 @@ object KMeansModel extends MLReadable[KMeansModel] { @Since("1.6.0") override def load(path: String): KMeansModel = super.load(path) - /** Helper class for storing model data */ - private case class Data(clusterIdx: Int, clusterCenter: Vector) - /** * We store all cluster centers in a single row and use this class to store model data by * Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility. */ private case class OldData(clusterCenters: Array[OldVector]) - /** [[MLWriter]] instance for [[KMeansModel]] */ - private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: cluster centers - val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) => - Data(idx, center) - } - val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) - } - } - private class KMeansModelReader extends MLReader[KMeansModel] { /** Checked against metadata when loading model */ @@ -232,14 +258,14 @@ object KMeansModel extends MLReadable[KMeansModel] { val dataPath = new Path(path, "data").toString val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) { - val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data] + val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData] data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML) } else { // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier. sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters } val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -310,15 +336,13 @@ class KMeans @Since("1.5.0") ( transformSchema(dataset.schema, logging = true) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + val instances = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) if (handlePersistence) { instances.persist(StorageLevel.MEMORY_AND_DISK) } - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure, maxIter, seed, tol) val algo = new MLlibKMeans() @@ -332,9 +356,10 @@ class KMeans @Since("1.5.0") ( val parentModel = algo.run(instances, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( - model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.transform(dataset), $(predictionCol), $(featuresCol), $(k), parentModel.numIter) model.setSummary(Some(summary)) + instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logSuccess(model) if (handlePersistence) { instances.unpersist() @@ -363,6 +388,7 @@ object KMeans extends DefaultParamsReadable[KMeans] { * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. */ @Since("2.0.0") @Experimental @@ -370,4 +396,5 @@ class KMeansSummary private[clustering] ( predictions: DataFrame, predictionCol: String, featuresCol: String, - k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) + k: Int, + numIter: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 4bab670cc159f..fed42c959b5ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -43,7 +43,7 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, StructType} import org.apache.spark.util.PeriodicCheckpointer import org.apache.spark.util.VersionUtils @@ -345,7 +345,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM s" must be >= 1. Found value: $getTopicConcentration") } } - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) } @@ -366,7 +366,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM private object LDAParams { /** - * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] + * Equivalent to [[Metadata.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+. * * @param model [[LDA]] or [[LDAModel]] instance. This instance will be modified with @@ -391,7 +391,7 @@ private object LDAParams { s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") } case _ => // 2.0+ - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) } } } @@ -461,7 +461,8 @@ abstract class LDAModel private[ml] ( val transformer = oldLocalModel.getTopicDistributionMethod val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML } - dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF() + dataset.withColumn($(topicDistributionCol), + t(DatasetUtils.columnToVector(dataset, getFeaturesCol))).toDF() } else { logWarning("LDAModel.transform was called without any output columns. Set an output column" + " such as topicDistributionCol to produce results.") @@ -568,10 +569,14 @@ abstract class LDAModel private[ml] ( class LocalLDAModel private[ml] ( uid: String, vocabSize: Int, - @Since("1.6.0") override private[clustering] val oldLocalModel: OldLocalLDAModel, + private[clustering] val oldLocalModel_ : OldLocalLDAModel, sparkSession: SparkSession) extends LDAModel(uid, vocabSize, sparkSession) { + override private[clustering] def oldLocalModel: OldLocalLDAModel = { + oldLocalModel_.setSeed(getSeed) + } + @Since("1.6.0") override def copy(extra: ParamMap): LocalLDAModel = { val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession) @@ -938,7 +943,7 @@ object LDA extends MLReadable[LDA] { featuresCol: String): RDD[(Long, OldVector)] = { dataset .withColumn("docId", monotonically_increasing_id()) - .select("docId", featuresCol) + .select(col("docId"), DatasetUtils.columnToVector(dataset, featuresCol)) .rdd .map { case Row(docId: Long, features: Vector) => (docId, OldVectors.fromML(features)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala index 2c30a1d9aa947..1b9a3499947d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -18,21 +18,20 @@ package org.apache.spark.ml.clustering import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.Transformer import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types._ /** * Common params for PowerIterationClustering */ private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter - with HasPredictionCol { + with HasWeightCol { /** * The number of clusters to create (k). Must be > 1. Default: 2. @@ -66,62 +65,33 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has def getInitMode: String = $(initMode) /** - * Param for the name of the input column for vertex IDs. - * Default: "id" + * Param for the name of the input column for source vertex IDs. + * Default: "src" * @group param */ @Since("2.4.0") - val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.", + val srcCol = new Param[String](this, "srcCol", "Name of the input column for source vertex IDs.", (value: String) => value.nonEmpty) - setDefault(idCol, "id") - - /** @group getParam */ - @Since("2.4.0") - def getIdCol: String = getOrDefault(idCol) - - /** - * Param for the name of the input column for neighbors in the adjacency list representation. - * Default: "neighbors" - * @group param - */ - @Since("2.4.0") - val neighborsCol = new Param[String](this, "neighborsCol", - "Name of the input column for neighbors in the adjacency list representation.", - (value: String) => value.nonEmpty) - - setDefault(neighborsCol, "neighbors") - /** @group getParam */ @Since("2.4.0") - def getNeighborsCol: String = $(neighborsCol) + def getSrcCol: String = getOrDefault(srcCol) /** - * Param for the name of the input column for neighbors in the adjacency list representation. - * Default: "similarities" + * Name of the input column for destination vertex IDs. + * Default: "dst" * @group param */ @Since("2.4.0") - val similaritiesCol = new Param[String](this, "similaritiesCol", - "Name of the input column for neighbors in the adjacency list representation.", + val dstCol = new Param[String](this, "dstCol", + "Name of the input column for destination vertex IDs.", (value: String) => value.nonEmpty) - setDefault(similaritiesCol, "similarities") - /** @group getParam */ @Since("2.4.0") - def getSimilaritiesCol: String = $(similaritiesCol) + def getDstCol: String = $(dstCol) - protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType)) - SchemaUtils.checkColumnTypes(schema, $(neighborsCol), - Seq(ArrayType(IntegerType, containsNull = false), - ArrayType(LongType, containsNull = false))) - SchemaUtils.checkColumnTypes(schema, $(similaritiesCol), - Seq(ArrayType(FloatType, containsNull = false), - ArrayType(DoubleType, containsNull = false))) - SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) - } + setDefault(srcCol -> "src", dstCol -> "dst") } /** @@ -131,21 +101,8 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has * PIC finds a very low-dimensional embedding of a dataset using truncated power * iteration on a normalized pair-wise similarity matrix of the data. * - * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix - * is a symmetric matrix whose entries are non-negative similarities between items. - * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes: - * - `idCol`: vertex ID - * - `neighborsCol`: neighbors of vertex in `idCol` - * - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex - * in `idCol` and each neighbor in `neighborsCol` - * PIC returns a cluster assignment for each input vertex. It appends a new column `predictionCol` - * containing the cluster assignment in `[0,k)` for each row (vertex). - * - * Notes: - * - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation. - * Transform runs the iterative PIC algorithm to cluster the whole input dataset. - * - Input validation: This validates that similarities are non-negative but does NOT validate - * that the input matrix is symmetric. + * This class is not yet an Estimator/Transformer, use `assignClusters` method to run the + * PowerIterationClustering algorithm. * * @see * Spectral clustering (Wikipedia) @@ -154,7 +111,7 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has @Experimental class PowerIterationClustering private[clustering] ( @Since("2.4.0") override val uid: String) - extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable { + extends PowerIterationClusteringParams with DefaultParamsWritable { setDefault( k -> 2, @@ -164,10 +121,6 @@ class PowerIterationClustering private[clustering] ( @Since("2.4.0") def this() = this(Identifiable.randomUID("PowerIterationClustering")) - /** @group setParam */ - @Since("2.4.0") - def setPredictionCol(value: String): this.type = set(predictionCol, value) - /** @group setParam */ @Since("2.4.0") def setK(value: Int): this.type = set(k, value) @@ -182,66 +135,56 @@ class PowerIterationClustering private[clustering] ( /** @group setParam */ @Since("2.4.0") - def setIdCol(value: String): this.type = set(idCol, value) + def setSrcCol(value: String): this.type = set(srcCol, value) /** @group setParam */ @Since("2.4.0") - def setNeighborsCol(value: String): this.type = set(neighborsCol, value) + def setDstCol(value: String): this.type = set(dstCol, value) /** @group setParam */ @Since("2.4.0") - def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value) + def setWeightCol(value: String): this.type = set(weightCol, value) + /** + * Run the PIC algorithm and returns a cluster assignment for each input vertex. + * + * @param dataset A dataset with columns src, dst, weight representing the affinity matrix, + * which is the matrix A in the PIC paper. Suppose the src column value is i, + * the dst column value is j, the weight column value is similarity s,,ij,, + * which must be nonnegative. This is a symmetric matrix and hence + * s,,ij,, = s,,ji,,. For any (i, j) with nonzero similarity, there should be + * either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows with i = j are + * ignored, because we assume s,,ij,, = 0.0. + * + * @return A dataset that contains columns of vertex id and the corresponding cluster for the id. + * The schema of it will be: + * - id: Long + * - cluster: Int + */ @Since("2.4.0") - override def transform(dataset: Dataset[_]): DataFrame = { - transformSchema(dataset.schema, logging = true) + def assignClusters(dataset: Dataset[_]): DataFrame = { + val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) { + lit(1.0) + } else { + col($(weightCol)).cast(DoubleType) + } - val sparkSession = dataset.sparkSession - val idColValue = $(idCol) - val rdd: RDD[(Long, Long, Double)] = - dataset.select( - col($(idCol)).cast(LongType), - col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)), - col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false)) - ).rdd.flatMap { - case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) => - require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " + - s"equal to the the length of the neighbor similarity list. Row for ID " + - s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " + - s"of length ${sims.length}.") - nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map { - case (nbr, similarity) => (id, nbr, similarity) - } - } + SchemaUtils.checkColumnTypes(dataset.schema, $(srcCol), Seq(IntegerType, LongType)) + SchemaUtils.checkColumnTypes(dataset.schema, $(dstCol), Seq(IntegerType, LongType)) + val rdd: RDD[(Long, Long, Double)] = dataset.select( + col($(srcCol)).cast(LongType), + col($(dstCol)).cast(LongType), + w).rdd.map { + case Row(src: Long, dst: Long, weight: Double) => (src, dst, weight) + } val algorithm = new MLlibPowerIterationClustering() .setK($(k)) .setInitializationMode($(initMode)) .setMaxIterations($(maxIter)) val model = algorithm.run(rdd) - val predictionsRDD: RDD[Row] = model.assignments.map { assignment => - Row(assignment.id, assignment.cluster) - } - - val predictionsSchema = StructType(Seq( - StructField($(idCol), LongType, nullable = false), - StructField($(predictionCol), IntegerType, nullable = false))) - val predictions = { - val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema) - dataset.schema($(idCol)).dataType match { - case _: LongType => - uncastPredictions - case otherType => - uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol))) - } - } - - dataset.join(predictions, $(idCol)) - } - - @Since("2.4.0") - override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema) + import dataset.sparkSession.implicits._ + model.assignments.toDF } @Since("2.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index 41eaaf9679914..a906e954fecd5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -238,7 +238,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject val model = new BucketedRandomProjectionLSHModel(metadata.uid, randUnitVectors.rowIter.toArray) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index f49c410cbcfe2..f99649f7fa164 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -217,8 +217,6 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } - - override def write: MLWriter = new Bucketizer.BucketizerWriter(this) } @Since("1.6.0") @@ -296,28 +294,6 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { } } - - private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // SPARK-23377: The default params will be saved and loaded as user-supplied params. - // Once `inputCols` is set, the default value of `outputCol` param causes the error - // when checking exclusive params. As a temporary to fix it, we skip the default value - // of `outputCol` if `inputCols` is set when saving the metadata. - // TODO: If we modify the persistence mechanism later to better handle default params, - // we can get rid of this. - var paramWithoutOutputCol: Option[JValue] = None - if (instance.isSet(instance.inputCols)) { - val params = instance.extractParamMap().toSeq - val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => - p.name -> parse(p.jsonEncode(v)) - }.toList - paramWithoutOutputCol = Some(render(jsonParams)) - } - DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) - } - } - @Since("1.6.0") override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 16abc4949dea3..dbfb199ccd58f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -334,7 +334,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { val selectedFeatures = data.getAs[Seq[Int]](0).toArray val oldModel = new feature.ChiSqSelectorModel(selectedFeatures) val model = new ChiSqSelectorModel(metadata.uid, oldModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 9e0ed437e7bfc..10c48c3f52085 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -363,7 +363,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { .head() val vocabulary = data.getAs[Seq[String]](0).toArray val model = new CountVectorizerModel(metadata.uid, vocabulary) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 46a0730f5ddb8..58897cca4e5c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -182,7 +182,7 @@ object IDFModel extends MLReadable[IDFModel] { .select("idf") .head() val model = new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf))) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 730ee9fc08db8..1c074e204ad99 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -262,7 +262,7 @@ object ImputerModel extends MLReadable[ImputerModel] { val dataPath = new Path(path, "data").toString val surrogateDF = sqlContext.read.parquet(dataPath) val model = new ImputerModel(metadata.uid, surrogateDF) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 85f9732f79f67..90eceb0d61b40 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -172,7 +172,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { .select("maxAbs") .head() val model = new MaxAbsScalerModel(metadata.uid, maxAbs) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 556848e45532d..a043033e96724 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -66,7 +66,7 @@ class MinHashLSHModel private[ml]( val elemsList = elems.toSparse.indices.toList val hashValues = randCoefficients.map { case (a, b) => elemsList.map { elem: Int => - ((1 + elem) * a + b) % MinHashLSH.HASH_PRIME + ((1L + elem) * a + b) % MinHashLSH.HASH_PRIME }.min.toDouble } // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 @@ -205,7 +205,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] { .map(tuple => (tuple(0), tuple(1))).toArray val model = new MinHashLSHModel(metadata.uid, randCoefficients) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index f648deced54cd..2e0ae4af66f06 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -243,7 +243,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { .select("originalMin", "originalMax") .head() val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index bd1e3426c8780..4a44f3186538d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -386,7 +386,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { .head() val categorySizes = data.getAs[Seq[Int]](0).toArray val model = new OneHotEncoderModel(metadata.uid, categorySizes) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 4143d864d7930..8172491a517d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -220,7 +220,7 @@ object PCAModel extends MLReadable[PCAModel] { new PCAModel(metadata.uid, pc.asML, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 3b4c25478fb1d..56e2c543d100a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -253,35 +253,11 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) - - override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this) } @Since("1.6.0") object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { - private[QuantileDiscretizer] - class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // SPARK-23377: The default params will be saved and loaded as user-supplied params. - // Once `inputCols` is set, the default value of `outputCol` param causes the error - // when checking exclusive params. As a temporary to fix it, we skip the default value - // of `outputCol` if `inputCols` is set when saving the metadata. - // TODO: If we modify the persistence mechanism later to better handle default params, - // we can get rid of this. - var paramWithoutOutputCol: Option[JValue] = None - if (instance.isSet(instance.inputCols)) { - val params = instance.extractParamMap().toSeq - val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => - p.name -> parse(p.jsonEncode(v)) - }.toList - paramWithoutOutputCol = Some(render(jsonParams)) - } - DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) - } - } - @Since("1.6.0") override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index e214765e3307f..55e595eee6ffb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -446,7 +446,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -510,7 +510,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { val columnsToPrune = data.getAs[Seq[String]](0).toSet val pruner = new ColumnPruner(metadata.uid, columnsToPrune) - DefaultParamsReader.getAndSetParams(pruner, metadata) + metadata.getAndSetParams(pruner) pruner } } @@ -602,7 +602,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite val prefixesToRewrite = data.getAs[Map[String, String]](1) val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite) - DefaultParamsReader.getAndSetParams(rewriter, metadata) + metadata.getAndSetParams(rewriter) rewriter } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 8f125d8fd51d2..91b0707dec3f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -212,7 +212,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { .select("std", "mean") .head() val model = new StandardScalerModel(metadata.uid, std, mean) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 3fcd84c029e61..0f946dd2e015b 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -17,9 +17,11 @@ package org.apache.spark.ml.feature +import java.util.Locale + import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer -import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -84,7 +86,27 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String @Since("1.5.0") def getCaseSensitive: Boolean = $(caseSensitive) - setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), caseSensitive -> false) + /** + * Locale of the input for case insensitive matching. Ignored when [[caseSensitive]] + * is true. + * Default: Locale.getDefault.toString + * @group param + */ + @Since("2.4.0") + val locale: Param[String] = new Param[String](this, "locale", + "Locale of the input for case insensitive matching. Ignored when caseSensitive is true.", + ParamValidators.inArray[String](Locale.getAvailableLocales.map(_.toString))) + + /** @group setParam */ + @Since("2.4.0") + def setLocale(value: String): this.type = set(locale, value) + + /** @group getParam */ + @Since("2.4.0") + def getLocale: String = $(locale) + + setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), + caseSensitive -> false, locale -> Locale.getDefault.toString) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { @@ -95,8 +117,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String terms.filter(s => !stopWordsSet.contains(s)) } } else { - // TODO: support user locale (SPARK-15064) - val toLower = (s: String) => if (s != null) s.toLowerCase else s + val lc = new Locale($(locale)) + val toLower = (s: String) => if (s != null) s.toLowerCase(lc) else s val lowerStopWords = $(stopWords).map(toLower(_)).toSet udf { terms: Seq[String] => terms.filter(s => !lowerStopWords.contains(toLower(s))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 67cdb097217a2..a833d8b270cf1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -315,7 +315,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { .head() val labels = data.getAs[Seq[String]](0).toArray val model = new StringIndexerModel(metadata.uid, labels) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index e6ec4e2e36ff0..0e7396a621dbd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -537,7 +537,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] { val numFeatures = data.getAs[Int](0) val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1) val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index fe3306e1e50d6..fc9996d69ba72 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -410,7 +410,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] { } val model = new Word2VecModel(metadata.uid, oldModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 3d041fc80eb7f..d7fbe28ae7a64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -161,6 +161,8 @@ class FPGrowth @Since("2.2.0") ( private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { val handlePersistence = dataset.storageLevel == StorageLevel.NONE + val instr = Instrumentation.create(this, dataset) + instr.logParams(params: _*) val data = dataset.select($(itemsCol)) val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) @@ -183,7 +185,9 @@ class FPGrowth @Since("2.2.0") ( items.unpersist() } - copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) + val model = copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) + instr.logSuccess(model) + model } @Since("2.2.0") @@ -335,7 +339,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { val dataPath = new Path(path, "data").toString val frequentItems = sparkSession.read.parquet(dataPath) val model = new FPGrowthModel(metadata.uid, frequentItems) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala new file mode 100644 index 0000000000000..bd1c1a8885201 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.fpm + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} + +/** + * :: Experimental :: + * A parallel PrefixSpan algorithm to mine frequent sequential patterns. + * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + * Efficiently by Prefix-Projected Pattern Growth + * (see here). + * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to + * run the PrefixSpan algorithm. + * + * @see Sequential Pattern Mining + * (Wikipedia) + */ +@Since("2.4.0") +@Experimental +final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params { + + @Since("2.4.0") + def this() = this(Identifiable.randomUID("prefixSpan")) + + /** + * Param for the minimal support level (default: `0.1`). + * Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are + * identified as frequent sequential patterns. + * @group param + */ + @Since("2.4.0") + val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " + + "sequential pattern. Sequential pattern that appears more than " + + "(minSupport * size-of-the-dataset) " + + "times will be output.", ParamValidators.gtEq(0.0)) + + /** @group getParam */ + @Since("2.4.0") + def getMinSupport: Double = $(minSupport) + + /** @group setParam */ + @Since("2.4.0") + def setMinSupport(value: Double): this.type = set(minSupport, value) + + /** + * Param for the maximal pattern length (default: `10`). + * @group param + */ + @Since("2.4.0") + val maxPatternLength = new IntParam(this, "maxPatternLength", + "The maximal length of the sequential pattern.", + ParamValidators.gt(0)) + + /** @group getParam */ + @Since("2.4.0") + def getMaxPatternLength: Int = $(maxPatternLength) + + /** @group setParam */ + @Since("2.4.0") + def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value) + + /** + * Param for the maximum number of items (including delimiters used in the internal storage + * format) allowed in a projected database before local processing (default: `32000000`). + * If a projected database exceeds this size, another iteration of distributed prefix growth + * is run. + * @group param + */ + @Since("2.4.0") + val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize", + "The maximum number of items (including delimiters used in the internal storage format) " + + "allowed in a projected database before local processing. If a projected database exceeds " + + "this size, another iteration of distributed prefix growth is run.", + ParamValidators.gt(0)) + + /** @group getParam */ + @Since("2.4.0") + def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize) + + /** @group setParam */ + @Since("2.4.0") + def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value) + + /** + * Param for the name of the sequence column in dataset (default "sequence"), rows with + * nulls in this column are ignored. + * @group param + */ + @Since("2.4.0") + val sequenceCol = new Param[String](this, "sequenceCol", "The name of the sequence column in " + + "dataset, rows with nulls in this column are ignored.") + + /** @group getParam */ + @Since("2.4.0") + def getSequenceCol: String = $(sequenceCol) + + /** @group setParam */ + @Since("2.4.0") + def setSequenceCol(value: String): this.type = set(sequenceCol, value) + + setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000, + sequenceCol -> "sequence") + + /** + * :: Experimental :: + * Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + * + * @param dataset A dataset or a dataframe containing a sequence column which is + * {{{ArrayType(ArrayType(T))}}} type, T is the item type for the input dataset. + * @return A `DataFrame` that contains columns of sequence and corresponding frequency. + * The schema of it will be: + * - `sequence: ArrayType(ArrayType(T))` (T is the item type) + * - `freq: Long` + */ + @Since("2.4.0") + def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = { + val sequenceColParam = $(sequenceCol) + val inputType = dataset.schema(sequenceColParam).dataType + require(inputType.isInstanceOf[ArrayType] && + inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType], + s"The input column must be ArrayType and the array element type must also be ArrayType, " + + s"but got $inputType.") + + val data = dataset.select(sequenceColParam) + val sequences = data.where(col(sequenceColParam).isNotNull).rdd + .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray) + + val mllibPrefixSpan = new mllibPrefixSpan() + .setMinSupport($(minSupport)) + .setMaxPatternLength($(maxPatternLength)) + .setMaxLocalProjDBSize($(maxLocalProjDBSize)) + + val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq)) + val schema = StructType(Seq( + StructField("sequence", dataset.schema(sequenceColParam).dataType, nullable = false), + StructField("freq", LongType, nullable = false))) + val freqSequences = dataset.sparkSession.createDataFrame(rows, schema) + + freqSequences + } + + @Since("2.4.0") + override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra) + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala index 8c975a2fba8ca..f1579ec5844a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala @@ -42,9 +42,11 @@ private object RecursiveFlag { val old = Option(hadoopConf.get(flagName)) hadoopConf.set(flagName, value.toString) try f finally { - old match { - case Some(v) => hadoopConf.set(flagName, v) - case None => hadoopConf.unset(flagName) + // avoid false positive of DLS_DEAD_LOCAL_STORE_IN_RETURN by SpotBugs + if (old.isDefined) { + hadoopConf.set(flagName, old.get) + } else { + hadoopConf.unset(flagName) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala index 6961b45f55e4d..572b8cf0051b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.optim -import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.util.OptionalInstrumentation import org.apache.spark.rdd.RDD /** @@ -61,9 +61,12 @@ private[ml] class IterativelyReweightedLeastSquares( val fitIntercept: Boolean, val regParam: Double, val maxIter: Int, - val tol: Double) extends Logging with Serializable { + val tol: Double) extends Serializable { - def fit(instances: RDD[OffsetInstance]): IterativelyReweightedLeastSquaresModel = { + def fit( + instances: RDD[OffsetInstance], + instr: OptionalInstrumentation = OptionalInstrumentation.create( + classOf[IterativelyReweightedLeastSquares])): IterativelyReweightedLeastSquaresModel = { var converged = false var iter = 0 @@ -83,7 +86,8 @@ private[ml] class IterativelyReweightedLeastSquares( // Estimate new model model = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, - standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + standardizeFeatures = false, standardizeLabel = false) + .fit(newInstances, instr = instr) // Check convergence val oldCoefficients = oldModel.coefficients @@ -96,14 +100,14 @@ private[ml] class IterativelyReweightedLeastSquares( if (maxTol < tol) { converged = true - logInfo(s"IRLS converged in $iter iterations.") + instr.logInfo(s"IRLS converged in $iter iterations.") } - logInfo(s"Iteration $iter : relative tolerance = $maxTol") + instr.logInfo(s"Iteration $iter : relative tolerance = $maxTol") iter = iter + 1 if (iter == maxIter) { - logInfo(s"IRLS reached the max number of iterations: $maxIter.") + instr.logInfo(s"IRLS reached the max number of iterations: $maxIter.") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index c5c9c8eb2bd29..1b7c15f1f0a8c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.optim -import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.util.OptionalInstrumentation import org.apache.spark.rdd.RDD /** @@ -81,13 +81,11 @@ private[ml] class WeightedLeastSquares( val standardizeLabel: Boolean, val solverType: WeightedLeastSquares.Solver = WeightedLeastSquares.Auto, val maxIter: Int = 100, - val tol: Double = 1e-6) extends Logging with Serializable { + val tol: Double = 1e-6 + ) extends Serializable { import WeightedLeastSquares._ require(regParam >= 0.0, s"regParam cannot be negative: $regParam") - if (regParam == 0.0) { - logWarning("regParam is zero, which might cause numerical instability and overfitting.") - } require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0, s"elasticNetParam must be in [0, 1]: $elasticNetParam") require(maxIter >= 0, s"maxIter must be a positive integer: $maxIter") @@ -96,10 +94,17 @@ private[ml] class WeightedLeastSquares( /** * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s. */ - def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = { + def fit( + instances: RDD[Instance], + instr: OptionalInstrumentation = OptionalInstrumentation.create(classOf[WeightedLeastSquares]) + ): WeightedLeastSquaresModel = { + if (regParam == 0.0) { + instr.logWarning("regParam is zero, which might cause numerical instability and overfitting.") + } + val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_)) summary.validate() - logInfo(s"Number of instances: ${summary.count}.") + instr.logInfo(s"Number of instances: ${summary.count}.") val k = if (fitIntercept) summary.k + 1 else summary.k val numFeatures = summary.k val triK = summary.triK @@ -114,11 +119,12 @@ private[ml] class WeightedLeastSquares( if (rawBStd == 0) { if (fitIntercept || rawBBar == 0.0) { if (rawBBar == 0.0) { - logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + - s"and the intercept will all be zero; as a result, training is not needed.") + instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + + s"coefficients and the intercept will all be zero; as a result, training is not " + + s"needed.") } else { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + + instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + + s"will be zeros and the intercept will be the mean of the label; as a result, " + s"training is not needed.") } val coefficients = new DenseVector(Array.ofDim(numFeatures)) @@ -128,7 +134,7 @@ private[ml] class WeightedLeastSquares( } else { require(!(regParam > 0.0 && standardizeLabel), "The standard deviation of the label is " + "zero. Model cannot be regularized with standardization=true") - logWarning(s"The standard deviation of the label is zero. Consider setting " + + instr.logWarning(s"The standard deviation of the label is zero. Consider setting " + s"fitIntercept=true.") } } @@ -256,7 +262,7 @@ private[ml] class WeightedLeastSquares( // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to // Quasi-Newton solver. case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto => - logWarning("Cholesky solver failed due to singular covariance matrix. " + + instr.logWarning("Cholesky solver failed due to singular covariance matrix. " + "Retrying with Quasi-Newton solver.") // ab and aa were modified in place, so reconstruct them val _aa = getAtA(aaBarValues, aBarValues) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 9a83a5882ce29..e6c347ed17c15 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -865,10 +865,10 @@ trait Params extends Identifiable with Serializable { } /** Internal param map for user-supplied values. */ - private val paramMap: ParamMap = ParamMap.empty + private[ml] val paramMap: ParamMap = ParamMap.empty /** Internal param map for default values. */ - private val defaultParamMap: ParamMap = ParamMap.empty + private[ml] val defaultParamMap: ParamMap = ParamMap.empty /** Validates that the input param belongs to this instance. */ private def shouldOwn(param: Param[_]): Unit = { @@ -905,6 +905,15 @@ trait Params extends Identifiable with Serializable { } } +private[ml] object Params { + /** + * Sets a default param value for a `Params`. + */ + private[ml] final def setDefault[T](params: Params, param: Param[T], value: T): Unit = { + params.defaultParamMap.put(param -> value) + } +} + /** * :: DeveloperApi :: * Java-friendly wrapper for [[Params]]. diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index b9c3170cc3c28..7e08675f834da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -95,7 +95,10 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" + " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"), isValid = "(value: String) => " + - "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)") + "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)"), + ParamDesc[String]("validationIndicatorCol", "name of the column that indicates whether " + + "each row is for training or for validation. False indicates training; true indicates " + + "validation.") ) val code = genSharedParams(params) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 282ea6ebcbf7f..5928a0749f738 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -523,4 +523,21 @@ trait HasDistanceMeasure extends Params { /** @group getParam */ final def getDistanceMeasure: String = $(distanceMeasure) } + +/** + * Trait for shared param validationIndicatorCol. This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasValidationIndicatorCol extends Params { + + /** + * Param for name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.. + * @group param + */ + final val validationIndicatorCol: Param[String] = new Param[String](this, "validationIndicatorCol", "name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.") + + /** @group getParam */ + final def getValidationIndicatorCol: String = $(validationIndicatorCol) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 81a8f50761e0e..a23f9552b9e5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -529,7 +529,7 @@ object ALSModel extends MLReadable[ALSModel] { val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 4b46c3831d75f..e27a96e1f5dfc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -237,7 +237,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) { - logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. This behavior is different from R survival::survreg.") } @@ -423,7 +423,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] .head() val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 5cef5c9f21f1e..8bcf0793a64c1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -282,7 +282,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false) val model = new DecisionTreeRegressionModel(metadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 834aaa0e362d1..eb8b3c001436a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ /** @@ -145,21 +145,42 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) + /** @group setParam */ + @Since("2.4.0") + def setValidationIndicatorCol(value: String): this.type = { + set(validationIndicatorCol, value) + } + override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) - val numFeatures = oldDataset.first().features.size + + val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty + + val (trainDataset, validationDataset) = if (withValidation) { + ( + extractLabeledPoints(dataset.filter(not(col($(validationIndicatorCol))))), + extractLabeledPoints(dataset.filter(col($(validationIndicatorCol)))) + ) + } else { + (extractLabeledPoints(dataset), null) + } + val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) - val instr = Instrumentation.create(this, oldDataset) + val instr = Instrumentation.create(this, dataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) instr.logNumFeatures(numFeatures) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed), $(featureSubsetStrategy)) + val (baseLearners, learnerWeights) = if (withValidation) { + GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } else { + GradientBoostedTrees.run(trainDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m @@ -269,6 +290,21 @@ class GBTRegressionModel private[ml]( new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights) } + /** + * Method to compute error or loss for every iteration of gradient boosting. + * + * @param dataset Dataset for validation. + * @param loss The loss function used to compute error. Supported options: squared, absolute + */ + @Since("2.4.0") + def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = { + val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => LabeledPoint(label, features) + } + GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, + convertToOldLossType(loss), OldAlgo.Regression) + } + @Since("2.0.0") override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this) } @@ -311,7 +347,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } @@ -319,7 +355,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { s" trees based on metadata but found ${trees.length} trees.") val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 9f1f2405c428e..143c8a3548b1f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -404,7 +404,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val } val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) - val wlsModel = optimizer.fit(instances) + val wlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) val model = copyValues( new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) .setParent(this)) @@ -418,10 +418,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val OffsetInstance(label, weight, offset, features) } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). - val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) + val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam), + instr = OptionalInstrumentation.create(instr)) val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc, $(fitIntercept), $(regParam), $(maxIter), $(tol)) - val irlsModel = optimizer.fit(instances) + val irlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) val model = copyValues( new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) .setParent(this)) @@ -471,6 +472,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine private[regression] val epsilon: Double = 1E-16 + private[regression] def ylogy(y: Double, mu: Double): Double = { + if (y == 0) 0.0 else y * math.log(y / mu) + } + /** * Wrapper of family and link combination used in the model. */ @@ -488,7 +493,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine def initialize( instances: RDD[OffsetInstance], fitIntercept: Boolean, - regParam: Double): WeightedLeastSquaresModel = { + regParam: Double, + instr: OptionalInstrumentation = OptionalInstrumentation.create( + classOf[GeneralizedLinearRegression]) + ): WeightedLeastSquaresModel = { val newInstances = instances.map { instance => val mu = family.initialize(instance.label, instance.weight) val eta = predict(mu) - instance.offset @@ -497,7 +505,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine // TODO: Make standardizeFeatures and standardizeLabel configurable. val initialModel = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) - .fit(newInstances) + .fit(newInstances, instr) initialModel } @@ -725,10 +733,6 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu * (1.0 - mu) - private def ylogy(y: Double, mu: Double): Double = { - if (y == 0) 0.0 else y * math.log(y / mu) - } - override def deviance(y: Double, mu: Double, weight: Double): Double = { 2.0 * weight * (ylogy(y, mu) + ylogy(1.0 - y, 1.0 - mu)) } @@ -783,7 +787,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu override def deviance(y: Double, mu: Double, weight: Double): Double = { - 2.0 * weight * (y * math.log(y / mu) - (y - mu)) + 2.0 * weight * (ylogy(y, mu) - (y - mu)) } override def aic( @@ -1146,7 +1150,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 8faab52ea474b..b046897ab2b7e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -308,7 +308,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { val model = new IsotonicRegressionModel( metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic)) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f67d9d831f327..c45ade94a4e33 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -339,7 +339,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = $(elasticNetParam), $(standardization), true, solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol)) - val model = optimizer.fit(instances) + val model = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) // When it is trained by WeightedLeastSquares, training summary does not // attach returned model. val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept)) @@ -378,6 +378,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val yMean = ySummarizer.mean(0) val rawYStd = math.sqrt(ySummarizer.variance(0)) + + instr.logNumExamples(ySummarizer.count) + instr.logNamedValue(Instrumentation.loggerTags.meanOfLabels, yMean) + instr.logNamedValue(Instrumentation.loggerTags.varianceOfLabels, rawYStd) + if (rawYStd == 0.0) { if ($(fitIntercept) || yMean == 0.0) { // If the rawYStd==0 and fitIntercept==true, then the intercept is yMean with @@ -385,11 +390,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String // Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of // the fitIntercept. if (yMean == 0.0) { - logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + - s"and the intercept will all be zero; as a result, training is not needed.") + instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + + s"coefficients and the intercept will all be zero; as a result, training is not " + + s"needed.") } else { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + + instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + + s"will be zeros and the intercept will be the mean of the label; as a result, " + s"training is not needed.") } if (handlePersistence) instances.unpersist() @@ -415,7 +421,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. " + "Model cannot be regularized.") - logWarning(s"The standard deviation of the label is zero. " + + instr.logWarning(s"The standard deviation of the label is zero. " + "Consider setting fitIntercept=true.") } } @@ -430,7 +436,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) { - logWarning("Fitting LinearRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting LinearRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. This behavior is the same as R glmnet but different from LIBSVM.") } @@ -522,7 +528,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } @@ -746,7 +752,7 @@ private class InternalLinearRegressionModelWriter /** A writer for LinearRegression that handles the "pmml" format */ private class PMMLLinearRegressionModelWriter - extends MLWriterFormat with MLFormatRegister { + extends MLWriterFormat with MLFormatRegister { override def format(): String = "pmml" @@ -799,7 +805,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { new LinearRegressionModel(metadata.uid, coefficients, intercept, scale) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 7f77398ba2a22..4509f85aafd12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -276,14 +276,14 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 056a94b351f79..905870178e549 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -108,9 +108,11 @@ private[spark] object RandomForest extends Logging { case Some(instrumentation) => instrumentation.logNumFeatures(metadata.numFeatures) instrumentation.logNumClasses(metadata.numClasses) + instrumentation.logNumExamples(metadata.numExamples) case None => logInfo("numFeatures: " + metadata.numFeatures) logInfo("numClasses: " + metadata.numClasses) + logInfo("numExamples: " + metadata.numExamples) } // Find the splits and the corresponding bins (interval between the splits) using a sample diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 81b6222acc7ce..00157fe63af41 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -21,6 +21,7 @@ import java.util.Locale import scala.util.Try +import org.apache.spark.annotation.Since import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -460,18 +461,34 @@ private[ml] trait RandomForestRegressorParams * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize { - - /* TODO: Add this doc when we add this param. SPARK-7132 - * Threshold for stopping early when runWithValidation is used. - * If the error rate on the validation input changes by less than the validationTol, - * then learning will stop early (before [[numIterations]]). - * This parameter is ignored when run is used. - * (default = 1e-5) +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize + with HasValidationIndicatorCol { + + /** + * Threshold for stopping early when fit with validation is used. + * (This parameter is ignored when fit without validation is used.) + * The decision to stop early is decided based on this logic: + * If the current loss on the validation set is greater than 0.01, the diff + * of validation error is compared to relative tolerance which is + * validationTol * (current loss on the validation set). + * If the current loss on the validation set is less than or equal to 0.01, + * the diff of validation error is compared to absolute tolerance which is + * validationTol * 0.01. * @group param + * @see validationIndicatorCol */ - // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") - // validationTol -> 1e-5 + @Since("2.4.0") + final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", + "Threshold for stopping early when fit with validation is used." + + "If the error rate on the validation input changes by less than the validationTol," + + "then learning will stop early (before `maxIter`)." + + "This parameter is ignored when fit without validation is used.", + ParamValidators.gtEq(0.0) + ) + + /** @group getParam */ + @Since("2.4.0") + final def getValidationTol: Double = $(validationTol) /** * @deprecated This method is deprecated and will be removed in 3.0.0. @@ -497,7 +514,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setStepSize(value: Double): this.type = set(stepSize, value) - setDefault(maxIter -> 20, stepSize -> 0.1) + setDefault(maxIter -> 20, stepSize -> 0.1, validationTol -> 0.01) setDefault(featureSubsetStrategy -> "all") @@ -507,7 +524,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance) // NOTE: The old API does not support "seed" so we ignore it. - new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) + new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize, getValidationTol) } /** Get old Gradient Boosting Loss type */ @@ -579,7 +596,11 @@ private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldLoss = { - getLossType match { + convertToOldLossType(getLossType) + } + + private[ml] def convertToOldLossType(loss: String): OldLoss = { + loss match { case "squared" => OldSquaredError case "absolute" => OldAbsoluteError case _ => diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index c2826dcc08634..f327f37bad204 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -144,7 +144,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => val trainingDataset = sparkSession.createDataFrame(training, schema).cache() val validationDataset = sparkSession.createDataFrame(validation, schema).cache() - logDebug(s"Train split $splitIndex with multiple sets of parameters.") + instr.logDebug(s"Train split $splitIndex with multiple sets of parameters.") // Fit models in a Future for training in parallel val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => @@ -155,7 +155,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) } // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) - logDebug(s"Got metric $metric for model trained with $paramMap.") + instr.logDebug(s"Got metric $metric for model trained with $paramMap.") metric } (executionContext) } @@ -169,12 +169,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) foldMetrics }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits - logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") + instr.logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") val (bestMetric, bestIndex) = if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) else metrics.zipWithIndex.minBy(_._1) - logInfo(s"Best set of parameters:\n${epm(bestIndex)}") - logInfo(s"Best cross-validation metric: $bestMetric.") + instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + instr.logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) copyValues(new CrossValidatorModel(uid, bestModel, metrics) @@ -234,8 +234,7 @@ object CrossValidator extends MLReadable[CrossValidator] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(cv, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(cv, skipParams = Option(List("estimatorParamMaps"))) cv } } @@ -424,8 +423,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 8d1b9a8ddab59..14d6a69c36747 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -143,7 +143,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St } else None // Fit models in a Future for training in parallel - logDebug(s"Train split with multiple sets of parameters.") + instr.logDebug(s"Train split with multiple sets of parameters.") val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] @@ -153,7 +153,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St } // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) - logDebug(s"Got metric $metric for model trained with $paramMap.") + instr.logDebug(s"Got metric $metric for model trained with $paramMap.") metric } (executionContext) } @@ -165,12 +165,12 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St trainingDataset.unpersist() validationDataset.unpersist() - logInfo(s"Train validation split metrics: ${metrics.toSeq}") + instr.logInfo(s"Train validation split metrics: ${metrics.toSeq}") val (bestMetric, bestIndex) = if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) else metrics.zipWithIndex.minBy(_._1) - logInfo(s"Best set of parameters:\n${epm(bestIndex)}") - logInfo(s"Best train validation split metric: $bestMetric.") + instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + instr.logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) copyValues(new TrainValidationSplitModel(uid, bestModel, metrics) @@ -228,8 +228,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(tvs, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(tvs, skipParams = Option(List("estimatorParamMaps"))) tvs } } @@ -407,8 +406,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala new file mode 100644 index 0000000000000..6af4b3ebc2cc2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Column, Dataset, Row} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType} + + +private[spark] object DatasetUtils { + + /** + * Cast a column in a Dataset to Vector type. + * + * The supported data types of the input column are + * - Vector + * - float/double type Array. + * + * Note: The returned column does not have Metadata. + * + * @param dataset input DataFrame + * @param colName column name. + * @return Vector column + */ + def columnToVector(dataset: Dataset[_], colName: String): Column = { + val columnDataType = dataset.schema(colName).dataType + columnDataType match { + case _: VectorUDT => col(colName) + case fdt: ArrayType => + val transferUDF = fdt.elementType match { + case _: FloatType => udf(f = (vector: Seq[Float]) => { + val inputArray = Array.fill[Double](vector.size)(0.0) + vector.indices.foreach(idx => inputArray(idx) = vector(idx).toDouble) + Vectors.dense(inputArray) + }) + case _: DoubleType => udf((vector: Seq[Double]) => { + Vectors.dense(vector.toArray) + }) + case other => + throw new IllegalArgumentException(s"Array[$other] column cannot be cast to Vector") + } + transferUDF(col(colName)) + case other => + throw new IllegalArgumentException(s"$other column cannot be cast to Vector") + } + } + + def columnToOldVector(dataset: Dataset[_], colName: String): RDD[OldVector] = { + dataset.select(columnToVector(dataset, colName)) + .rdd.map { + case Row(point: Vector) => OldVectors.fromML(point) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index e694bc27b2f1e..11f46eb9e4359 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -19,6 +19,8 @@ package org.apache.spark.ml.util import java.util.UUID +import scala.reflect.ClassTag + import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -28,6 +30,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.Param import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset +import org.apache.spark.util.Utils /** * A small wrapper that defines a training session for an estimator, and some methods to log @@ -40,11 +43,14 @@ import org.apache.spark.sql.Dataset * @tparam E the type of the estimator */ private[spark] class Instrumentation[E <: Estimator[_]] private ( - estimator: E, dataset: RDD[_]) extends Logging { + val estimator: E, + val dataset: RDD[_]) extends Logging { private val id = UUID.randomUUID() private val prefix = { - val className = estimator.getClass.getSimpleName + // estimator.getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead + val className = Utils.getSimpleName(estimator.getClass) s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " } @@ -55,6 +61,13 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( s" storageLevel=${dataset.getStorageLevel}") } + /** + * Logs a debug message with a prefix that uniquely identifies the training session. + */ + override def logDebug(msg: => String): Unit = { + super.logDebug(prefix + msg) + } + /** * Logs a warning message with a prefix that uniquely identifies the training session. */ @@ -103,6 +116,10 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( logNamedValue(Instrumentation.loggerTags.numClasses, num) } + def logNumExamples(num: Long): Unit = { + logNamedValue(Instrumentation.loggerTags.numExamples, num) + } + /** * Logs the value with customized name field. */ @@ -114,6 +131,23 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( log(compact(render(name -> value))) } + def logNamedValue(name: String, value: Double): Unit = { + log(compact(render(name -> value))) + } + + def logNamedValue(name: String, value: Array[String]): Unit = { + log(compact(render(name -> compact(render(value.toSeq))))) + } + + def logNamedValue(name: String, value: Array[Long]): Unit = { + log(compact(render(name -> compact(render(value.toSeq))))) + } + + def logNamedValue(name: String, value: Array[Double]): Unit = { + log(compact(render(name -> compact(render(value.toSeq))))) + } + + /** * Logs the successful completion of the training session. */ @@ -131,6 +165,8 @@ private[spark] object Instrumentation { val numFeatures = "numFeatures" val numClasses = "numClasses" val numExamples = "numExamples" + val meanOfLabels = "meanOfLabels" + val varianceOfLabels = "varianceOfLabels" } /** @@ -150,3 +186,56 @@ private[spark] object Instrumentation { } } + +/** + * A small wrapper that contains an optional `Instrumentation` object. + * Provide some log methods, if the containing `Instrumentation` object is defined, + * will log via it, otherwise will log via common logger. + */ +private[spark] class OptionalInstrumentation private( + val instrumentation: Option[Instrumentation[_ <: Estimator[_]]], + val className: String) extends Logging { + + protected override def logName: String = className + + override def logInfo(msg: => String) { + instrumentation match { + case Some(instr) => instr.logInfo(msg) + case None => super.logInfo(msg) + } + } + + override def logWarning(msg: => String) { + instrumentation match { + case Some(instr) => instr.logWarning(msg) + case None => super.logWarning(msg) + } + } + + override def logError(msg: => String) { + instrumentation match { + case Some(instr) => instr.logError(msg) + case None => super.logError(msg) + } + } +} + +private[spark] object OptionalInstrumentation { + + /** + * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object. + */ + def create[E <: Estimator[_]](instr: Instrumentation[E]): OptionalInstrumentation = { + new OptionalInstrumentation(Some(instr), + instr.estimator.getClass.getName.stripSuffix("$")) + } + + /** + * Creates an `OptionalInstrumentation` object from a `Class` object. + * The created `OptionalInstrumentation` object will log messages via common logger and use the + * specified class name as logger name. + */ + def create(clazz: Class[_]): OptionalInstrumentation = { + new OptionalInstrumentation(None, clazz.getName.stripSuffix("$")) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 7edcd498678cc..72a60e04360d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -39,7 +39,7 @@ import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.ml.tuning.ValidatorParams import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.util.{Utils, VersionUtils} /** * Trait for `MLWriter` and `MLReader`. @@ -421,6 +421,7 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid + * - defaultParamMap * - paramMap * - (optionally, extra metadata) * @@ -453,15 +454,20 @@ private[ml] object DefaultParamsWriter { paramMap: Option[JValue] = None): String = { val uid = instance.uid val cls = instance.getClass.getName - val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] + val params = instance.paramMap.toSeq + val defaultParams = instance.defaultParamMap.toSeq val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }.toList)) + val jsonDefaultParams = render(defaultParams.map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList) val basicMetadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) + ("paramMap" -> jsonParams) ~ + ("defaultParamMap" -> jsonDefaultParams) val metadata = extraMetadata match { case Some(jObject) => basicMetadata ~ jObject @@ -488,7 +494,7 @@ private[ml] class DefaultParamsReader[T] extends MLReader[T] { val cls = Utils.classForName(metadata.className) val instance = cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params] - DefaultParamsReader.getAndSetParams(instance, metadata) + metadata.getAndSetParams(instance) instance.asInstanceOf[T] } } @@ -499,6 +505,8 @@ private[ml] object DefaultParamsReader { * All info from metadata file. * * @param params paramMap, as a `JValue` + * @param defaultParams defaultParamMap, as a `JValue`. For metadata file prior to Spark 2.4, + * this is `JNothing`. * @param metadata All metadata, including the other fields * @param metadataJson Full metadata file String (for debugging) */ @@ -508,27 +516,90 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, + defaultParams: JValue, metadata: JValue, metadataJson: String) { + + private def getValueFromParams(params: JValue): Seq[(String, JValue)] = { + params match { + case JObject(pairs) => pairs + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: $metadataJson.") + } + } + /** * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name. * This can be useful for getting a Param value before an instance of `Params` - * is available. + * is available. This will look up `params` first, if not existing then looking up + * `defaultParams`. */ def getParamValue(paramName: String): JValue = { implicit val format = DefaultFormats - params match { + + // Looking up for `params` first. + var pairs = getValueFromParams(params) + var foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } + if (foundPairs.length == 0) { + // Looking up for `defaultParams` then. + pairs = getValueFromParams(defaultParams) + foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } + } + assert(foundPairs.length == 1, s"Expected one instance of Param '$paramName' but found" + + s" ${foundPairs.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) + + foundPairs.map(_._2).head + } + + /** + * Extract Params from metadata, and set them in the instance. + * This works if all Params (except params included by `skipParams` list) implement + * [[org.apache.spark.ml.param.Param.jsonDecode()]]. + * + * @param skipParams The params included in `skipParams` won't be set. This is useful if some + * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] + * and need special handling. + */ + def getAndSetParams( + instance: Params, + skipParams: Option[List[String]] = None): Unit = { + setParams(instance, skipParams, isDefault = false) + + // For metadata file prior to Spark 2.4, there is no default section. + val (major, minor) = VersionUtils.majorMinorVersion(sparkVersion) + if (major > 2 || (major == 2 && minor >= 4)) { + setParams(instance, skipParams, isDefault = true) + } + } + + private def setParams( + instance: Params, + skipParams: Option[List[String]], + isDefault: Boolean): Unit = { + implicit val format = DefaultFormats + val paramsToSet = if (isDefault) defaultParams else params + paramsToSet match { case JObject(pairs) => - val values = pairs.filter { case (pName, jsonValue) => - pName == paramName - }.map(_._2) - assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" + - s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) - values.head + pairs.foreach { case (paramName, jsonValue) => + if (skipParams == None || !skipParams.get.contains(paramName)) { + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + if (isDefault) { + Params.setDefault(instance, param, value) + } else { + instance.set(param, value) + } + } + } case _ => throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: $metadataJson.") + s"Cannot recognize JSON metadata: ${metadataJson}.") } } } @@ -561,43 +632,14 @@ private[ml] object DefaultParamsReader { val uid = (metadata \ "uid").extract[String] val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] + val defaultParams = metadata \ "defaultParamMap" val params = metadata \ "paramMap" if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) - } - - /** - * Extract Params from metadata, and set them in the instance. - * This works if all Params (except params included by `skipParams` list) implement - * [[org.apache.spark.ml.param.Param.jsonDecode()]]. - * - * @param skipParams The params included in `skipParams` won't be set. This is useful if some - * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] - * and need special handling. - * TODO: Move to [[Metadata]] method - */ - def getAndSetParams( - instance: Params, - metadata: Metadata, - skipParams: Option[List[String]] = None): Unit = { - implicit val format = DefaultFormats - metadata.params match { - case JObject(pairs) => - pairs.foreach { case (paramName, jsonValue) => - if (skipParams == None || !skipParams.get.contains(paramName)) { - val param = instance.getParam(paramName) - val value = param.jsonDecode(compact(render(jsonValue))) - instance.set(param, value) - } - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") - } + Metadata(className, uid, timestamp, sparkVersion, params, defaultParams, metadata, metadataStr) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 334410c9620de..d9a3f85ef9a24 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -17,7 +17,8 @@ package org.apache.spark.ml.util -import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType} +import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.sql.types._ /** @@ -101,4 +102,17 @@ private[spark] object SchemaUtils { require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.") StructType(schema.fields :+ col) } + + /** + * Check whether the given column in the schema is one of the supporting vector type: Vector, + * Array[Float]. Array[Double] + * @param schema input schema + * @param colName column name + */ + def validateVectorCompatibleColumn(schema: StructType, colName: String): Unit = { + val typeCandidates = List( new VectorUDT, + new ArrayType(DoubleType, false), + new ArrayType(FloatType, false)) + checkColumnTypes(schema, colName, typeCandidates) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index b32d3f252ae59..db3f074ecfbac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -572,10 +572,7 @@ private[python] class PythonMLLibAPI extends Serializable { data: JavaRDD[java.lang.Iterable[Any]], minSupport: Double, numPartitions: Int): FPGrowthModel[Any] = { - val fpg = new FPGrowth() - .setMinSupport(minSupport) - .setNumPartitions(numPartitions) - + val fpg = new FPGrowth(minSupport, numPartitions) val model = fpg.run(data.rdd.map(_.asScala.toArray)) new FPGrowthModelWrapper(model) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index b5b1be3490497..37ae8b1a6171a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -348,7 +348,7 @@ class KMeans private ( logInfo(s"The cost is $cost.") - new KMeansModel(centers.map(_.vector), distanceMeasure) + new KMeansModel(centers.map(_.vector), distanceMeasure, iteration) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index a78c21e838e44..e3a88b42fbf73 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -36,8 +36,9 @@ import org.apache.spark.sql.{Row, SparkSession} * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ @Since("0.8.0") -class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vector], - @Since("2.4.0") val distanceMeasure: String) +class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector], + @Since("2.4.0") val distanceMeasure: String, + private[spark] val numIter: Int) extends Saveable with Serializable with PMMLExportable { private val distanceMeasureInstance: DistanceMeasure = @@ -46,6 +47,10 @@ class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vec private val clusterCentersWithNorm = if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_)) + @Since("2.4.0") + private[spark] def this(clusterCenters: Array[Vector], distanceMeasure: String) = + this(clusterCenters: Array[Vector], distanceMeasure, -1) + @Since("1.1.0") def this(clusterCenters: Array[Vector]) = this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index b8a6e94248421..f915062d77389 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.{BoundedPriorityQueue, Utils} /** * Latent Dirichlet Allocation (LDA) model. @@ -194,6 +194,8 @@ class LocalLDAModel private[spark] ( override protected[spark] val gammaShape: Double = 100) extends LDAModel with Serializable { + private var seed: Long = Utils.random.nextLong() + @Since("1.3.0") override def k: Int = topics.numCols @@ -216,6 +218,21 @@ class LocalLDAModel private[spark] ( override protected def formatVersion = "1.0" + /** + * Random seed for cluster initialization. + */ + @Since("2.4.0") + def getSeed: Long = seed + + /** + * Set the random seed for cluster initialization. + */ + @Since("2.4.0") + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + @Since("1.5.0") override def save(sc: SparkContext, path: String): Unit = { LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, @@ -298,6 +315,7 @@ class LocalLDAModel private[spark] ( // by topic (columns of lambda) val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta) + val gammaSeed = this.seed // Sum bound components for each document: // component for prob(tokens) + component for prob(document-topic distribution) @@ -306,7 +324,7 @@ class LocalLDAModel private[spark] ( val localElogbeta = ElogbetaBc.value var docBound = 0.0D val (gammad: BDV[Double], _, _) = OnlineLDAOptimizer.variationalTopicInference( - termCounts, exp(localElogbeta), brzAlpha, gammaShape, k) + termCounts, exp(localElogbeta), brzAlpha, gammaShape, k, gammaSeed + id) val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad) // E[log p(doc | theta, beta)] @@ -352,6 +370,7 @@ class LocalLDAModel private[spark] ( val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k + val gammaSeed = this.seed documents.map { case (id: Long, termCounts: Vector) => if (termCounts.numNonzeros == 0) { @@ -362,7 +381,8 @@ class LocalLDAModel private[spark] ( expElogbetaBc.value, docConcentrationBrz, gammaShape, - k) + k, + gammaSeed + id) (id, Vectors.dense(normalize(gamma, 1.0).toArray)) } } @@ -376,6 +396,7 @@ class LocalLDAModel private[spark] ( val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k + val gammaSeed = this.seed (termCounts: Vector) => if (termCounts.numNonzeros == 0) { @@ -386,7 +407,8 @@ class LocalLDAModel private[spark] ( expElogbeta, docConcentrationBrz, gammaShape, - k) + k, + gammaSeed) Vectors.dense(normalize(gamma, 1.0).toArray) } } @@ -403,6 +425,7 @@ class LocalLDAModel private[spark] ( */ @Since("2.0.0") def topicDistribution(document: Vector): Vector = { + val gammaSeed = this.seed val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t) if (document.numNonzeros == 0) { Vectors.zeros(this.k) @@ -412,7 +435,8 @@ class LocalLDAModel private[spark] ( expElogbeta, this.docConcentration.asBreeze, gammaShape, - this.k) + this.k, + gammaSeed) Vectors.dense(normalize(gamma, 1.0).toArray) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 693a2a31f026b..f8e5f3ed76457 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -464,6 +465,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { val alpha = this.alpha.asBreeze val gammaShape = this.gammaShape val optimizeDocConcentration = this.optimizeDocConcentration + val seed = randomGenerator.nextLong() // If and only if optimizeDocConcentration is set true, // we calculate logphat in the same pass as other statistics. // No calculation of loghat happens otherwise. @@ -473,20 +475,21 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { None } - val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs => - val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) - - val stat = BDM.zeros[Double](k, vocabSize) - val logphatPartOption = logphatPartOptionBase() - var nonEmptyDocCount: Long = 0L - nonEmptyDocs.foreach { case (_, termCounts: Vector) => - nonEmptyDocCount += 1 - val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( - termCounts, expElogbetaBc.value, alpha, gammaShape, k) - stat(::, ids) := stat(::, ids) + sstats - logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) - } - Iterator((stat, logphatPartOption, nonEmptyDocCount)) + val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitionsWithIndex { + (index, docs) => + val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) + + val stat = BDM.zeros[Double](k, vocabSize) + val logphatPartOption = logphatPartOptionBase() + var nonEmptyDocCount: Long = 0L + nonEmptyDocs.foreach { case (_, termCounts: Vector) => + nonEmptyDocCount += 1 + val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, expElogbetaBc.value, alpha, gammaShape, k, seed + index) + stat(::, ids) := stat(::, ids) + sstats + logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) + } + Iterator((stat, logphatPartOption, nonEmptyDocCount)) } val elementWiseSum = ( @@ -578,7 +581,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { } override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { - new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape) + new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta) + .setSeed(randomGenerator.nextLong()) } } @@ -605,18 +609,20 @@ private[clustering] object OnlineLDAOptimizer { expElogbeta: BDM[Double], alpha: breeze.linalg.Vector[Double], gammaShape: Double, - k: Int): (BDV[Double], BDM[Double], List[Int]) = { + k: Int, + seed: Long): (BDV[Double], BDM[Double], List[Int]) = { val (ids: List[Int], cts: Array[Double]) = termCounts match { case v: DenseVector => ((0 until v.size).toList, v.values) case v: SparseVector => (v.indices.toList, v.values) } // Initialize the variational distribution q(theta|gamma) for the mini-batch + val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(seed)) val gammad: BDV[Double] = - new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K + new Gamma(gammaShape, 1.0 / gammaShape)(randBasis).samplesVector(k) // K val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K - val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids + val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids var meanGammaChange = 1D val ctsVector = new BDV[Double](cts) // ids diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 3ca75e8cdb97a..7a5e520d5818e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.random.XORShiftRandom * $$ * \begin{align} * c_t+1 &= [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] \\ - * n_t+t &= n_t * a + m_t + * n_t+1 &= n_t * a + m_t * \end{align} * $$ * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index f6b1143272d16..4f2b7e6f0764e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -162,7 +162,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { * */ @Since("1.3.0") -class FPGrowth private ( +class FPGrowth private[spark] ( private var minSupport: Double, private var numPartitions: Int) extends Logging with Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 3f8d65a378e2c..7aed2f3bd8a61 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -49,8 +49,7 @@ import org.apache.spark.storage.StorageLevel * * @param minSupport the minimal support level of the sequential pattern, any pattern that appears * more than (minSupport * size-of-the-dataset) times will be output - * @param maxPatternLength the maximal length of the sequential pattern, any pattern that appears - * less than maxPatternLength will be output + * @param maxPatternLength the maximal length of the sequential pattern * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal * storage format) allowed in a projected database before local * processing. If a projected database exceeds this size, another diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index f0ee5496f9d1d..e6d2a8e2b900e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.RegressionLeafNode -import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -34,6 +34,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.LogLoss import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.lit import org.apache.spark.util.Utils /** @@ -365,6 +366,78 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { assert(mostImportantFeature !== mostIF) } + test("model evaluateEachIteration") { + val gbt = new GBTClassifier() + .setSeed(1L) + .setMaxDepth(2) + .setMaxIter(3) + .setLossType("logistic") + val model3 = gbt.fit(trainData.toDF) + val model1 = new GBTClassificationModel("gbt-cls-model-test1", + model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures, model3.numClasses) + val model2 = new GBTClassificationModel("gbt-cls-model-test2", + model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures, model3.numClasses) + + val evalArr = model3.evaluateEachIteration(validationData.toDF) + val remappedValidationData = validationData.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + val lossErr1 = GradientBoostedTrees.computeError(remappedValidationData, + model1.trees, model1.treeWeights, model1.getOldLossType) + val lossErr2 = GradientBoostedTrees.computeError(remappedValidationData, + model2.trees, model2.treeWeights, model2.getOldLossType) + val lossErr3 = GradientBoostedTrees.computeError(remappedValidationData, + model3.trees, model3.treeWeights, model3.getOldLossType) + + assert(evalArr(0) ~== lossErr1 relTol 1E-3) + assert(evalArr(1) ~== lossErr2 relTol 1E-3) + assert(evalArr(2) ~== lossErr3 relTol 1E-3) + } + + test("runWithValidation stops early and performs better on a validation dataset") { + val validationIndicatorCol = "validationIndicator" + val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false)) + val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true)) + + val numIter = 20 + for (lossType <- GBTClassifier.supportedLossTypes) { + val gbt = new GBTClassifier() + .setSeed(123) + .setMaxDepth(2) + .setLossType(lossType) + .setMaxIter(numIter) + val modelWithoutValidation = gbt.fit(trainDF) + + gbt.setValidationIndicatorCol(validationIndicatorCol) + val modelWithValidation = gbt.fit(trainDF.union(validationDF)) + + assert(modelWithoutValidation.numTrees === numIter) + // early stop + assert(modelWithValidation.numTrees < numIter) + + val (errorWithoutValidation, errorWithValidation) = { + val remappedRdd = validationData.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (GradientBoostedTrees.computeError(remappedRdd, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType), + GradientBoostedTrees.computeError(remappedRdd, modelWithValidation.trees, + modelWithValidation.treeWeights, modelWithValidation.getOldLossType)) + } + assert(errorWithValidation < errorWithoutValidation) + + val evaluationArray = GradientBoostedTrees + .evaluateEachIteration(validationData, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, + OldAlgo.Classification) + assert(evaluationArray.length === numIter) + assert(evaluationArray(modelWithValidation.numTrees) > + evaluationArray(modelWithValidation.numTrees - 1)) + var i = 1 + while (i < modelWithValidation.numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 36b7e51f93d01..75c2aeb146786 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -2751,6 +2751,12 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { assert(model.getFamily === family) } } + + test("toString") { + val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0) + val expected = "LogisticRegressionModel: uid = logReg, numClasses = 2, numFeatures = 3" + assert(model.toString === expected) + } } object LogisticRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 02880f96ae6d9..1b7780e171e77 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -17,16 +17,20 @@ package org.apache.spark.ml.clustering -import org.apache.spark.{SparkException, SparkFunSuite} +import scala.language.existentials + +import org.apache.spark.SparkException import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.clustering.DistanceMeasure -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset -class BisectingKMeansSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + +class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ final val k = 5 @transient var dataset: Dataset[_] = _ @@ -65,10 +69,13 @@ class BisectingKMeansSuite // Verify fit does not fail on very sparse data val model = bkm.fit(sparseDataset) - val result = model.transform(sparseDataset) - val numClusters = result.select("prediction").distinct().collect().length - // Verify we hit the edge case - assert(numClusters < k && numClusters > 1) + + testTransformerByGlobalCheckFunc[Tuple1[Vector]](sparseDataset.toDF(), model, "prediction") { + rows => + val numClusters = rows.distinct.length + // Verify we hit the edge case + assert(numClusters < k && numClusters > 1) + } } test("setter/getter") { @@ -101,19 +108,16 @@ class BisectingKMeansSuite val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = bkm.fit(dataset) assert(model.clusterCenters.length === k) - - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - val clusters = - transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet - assert(clusters.size === k) - assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName) { rows => + val clusters = rows.map(_.getAs[Int](predictionColName)).toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) + } + // Check validity of model summary val numRows = dataset.count() assert(model.hasSummary) @@ -129,6 +133,7 @@ class BisectingKMeansSuite assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 20) model.setSummary(None) assert(!model.hasSummary) @@ -182,6 +187,22 @@ class BisectingKMeansSuite model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) } + + test("BisectingKMeans with Array input") { + def trainAndComputeCost(dataset: Dataset[_]): Double = { + val model = new BisectingKMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.computeCost(dataset) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueCost = trainAndComputeCost(newDataset) + val doubleArrayCost = trainAndComputeCost(newDatasetD) + val floatArrayCost = trainAndComputeCost(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueCost ~== doubleArrayCost absTol 1e-6) + assert(trueCost ~== floatArrayCost absTol 1e-6) + } } object BisectingKMeansSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 08b800b7e4183..13bed9dbe3e89 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -17,21 +17,21 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.stat.distribution.MultivariateGaussian -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest { - import testImplicits._ import GaussianMixtureSuite._ + import testImplicits._ final val k = 5 private val seed = 538009335 @@ -118,15 +118,10 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.weights.length === k) assert(model.gaussians.length === k) - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName, probabilityColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - // Check prediction matches the highest probability, and probabilities sum to one. - transformed.select(predictionColName, probabilityColName).collect().foreach { - case Row(pred: Int, prob: Vector) => + testTransformer[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName, probabilityColName) { + case Row(_, pred: Int, prob: Vector) => val probArray = prob.toArray val predFromProb = probArray.zipWithIndex.maxBy(_._1)._2 assert(pred === predFromProb) @@ -150,6 +145,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 2) model.setSummary(None) assert(!model.hasSummary) @@ -256,6 +252,22 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext val expectedMatrix = GaussianMixture.unpackUpperTriangularMatrix(4, triangularValues) assert(symmetricMatrix === expectedMatrix) } + + test("GaussianMixture with Array input") { + def trainAndComputlogLikelihood(dataset: Dataset[_]): Double = { + val model = new GaussianMixture().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.summary.logLikelihood + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueLikelihood = trainAndComputlogLikelihood(newDataset) + val doubleLikelihood = trainAndComputlogLikelihood(newDatasetD) + val floatLikelihood = trainAndComputlogLikelihood(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueLikelihood ~== doubleLikelihood absTol 1e-6) + assert(trueLikelihood ~== floatLikelihood absTol 1e-6) + } } object GaussianMixtureSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 32830b39407ad..829c90fe34e94 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -17,19 +17,26 @@ package org.apache.spark.ml.clustering +import scala.language.existentials import scala.util.Random -import org.apache.spark.{SparkException, SparkFunSuite} +import org.dmg.pmml.{ClusteringModel, PMML} + +import org.apache.spark.SparkException import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils, PMMLReadWriteTest} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, + KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} private[clustering] case class TestRow(features: Vector) -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest { + + import testImplicits._ final val k = 5 @transient var dataset: Dataset[_] = _ @@ -103,15 +110,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR val model = kmeans.fit(dataset) assert(model.clusterCenters.length === k) - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName) { rows => + val clusters = rows.map(_.getAs[Int](predictionColName)).toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) } - val clusters = - transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet - assert(clusters.size === k) - assert(clusters === Set(0, 1, 2, 3, 4)) + assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) @@ -130,6 +135,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 1) model.setSummary(None) assert(!model.hasSummary) @@ -143,9 +149,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName) val transformed = model.transform(dataset.withColumnRenamed("features", featuresColName)) - Seq(featuresColName, predictionColName).foreach { column => - assert(transformed.columns.contains(column)) - } + assert(transformed.schema.fieldNames.toSet === Set(featuresColName, predictionColName)) assert(model.getFeaturesCol == featuresColName) assert(model.getPredictionCol == predictionColName) } @@ -194,6 +198,23 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(e.getCause.getMessage.contains("Cosine distance is not defined")) } + test("KMean with Array input") { + def trainAndComputeCost(dataset: Dataset[_]): Double = { + val model = new KMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.computeCost(dataset) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueCost = trainAndComputeCost(newDataset) + val doubleArrayCost = trainAndComputeCost(newDatasetD) + val floatArrayCost = trainAndComputeCost(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueCost ~== doubleArrayCost absTol 1e-6) + assert(trueCost ~== floatArrayCost absTol 1e-6) + } + + test("read/write") { def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { assert(model.clusterCenters === model2.clusterCenters) @@ -202,6 +223,27 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, KMeansSuite.allParamSettings, checkModelData) } + + test("pmml export") { + val clusterCenters = Array( + MLlibVectors.dense(1.0, 2.0, 6.0), + MLlibVectors.dense(1.0, 3.0, 0.0), + MLlibVectors.dense(1.0, 4.0, 6.0)) + val oldKmeansModel = new MLlibKMeansModel(clusterCenters) + val kmeansModel = new KMeansModel("", oldKmeansModel) + def checkModel(pmml: PMML): Unit = { + // Check the header descripiton is what we expect + assert(pmml.getHeader.getDescription === "k-means clustering") + // check that the number of fields match the single vector size + assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size) + // This verify that there is a model attached to the pmml object and the model is a clustering + // one. It also verifies that the pmml model has the same number of clusters of the spark + // model. + val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] + assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) + } + testPMMLWrite(sc, kmeansModel, checkModel) + } } object KMeansSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index e73bbc18d76bd..db92132d18b7b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.hadoop.fs.Path -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ - object LDASuite { def generateLDAData( spark: SparkSession, @@ -35,9 +34,8 @@ object LDASuite { vocabSize: Int): DataFrame = { val avgWC = 1 // average instances of each word in a doc val sc = spark.sparkContext - val rng = new java.util.Random() - rng.setSeed(1) val rdd = sc.parallelize(1 to rows).map { i => + val rng = new java.util.Random(i) Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble)) }.map(v => new TestRow(v)) spark.createDataFrame(rdd) @@ -60,7 +58,7 @@ object LDASuite { } -class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LDASuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -185,16 +183,11 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.topicsMatrix.numCols === k) assert(!model.isDistributed) - // transform() - val transformed = model.transform(dataset) - val expectedColumns = Array("features", lda.getTopicDistributionCol) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - transformed.select(lda.getTopicDistributionCol).collect().foreach { r => - val topicDistribution = r.getAs[Vector](0) - assert(topicDistribution.size === k) - assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) + testTransformer[Tuple1[Vector]](dataset.toDF(), model, + "features", lda.getTopicDistributionCol) { + case Row(_, topicDistribution: Vector) => + assert(topicDistribution.size === k) + assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) } // logLikelihood, logPerplexity @@ -252,6 +245,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val lda = new LDA() testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, LDASuite.allParamSettings, checkModelData) + + // Make sure the result is deterministic after saving and loading the model + val model = lda.fit(dataset) + val model2 = testDefaultReadWrite(model) + assert(model.logLikelihood(dataset) ~== model2.logLikelihood(dataset) absTol 1e-6) + assert(model.logPerplexity(dataset) ~== model2.logPerplexity(dataset) absTol 1e-6) } test("read/write DistributedLDAModel") { @@ -323,4 +322,21 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.getOptimizer === optimizer) } } + + test("LDA with Array input") { + def trainAndLogLikelihoodAndPerplexity(dataset: Dataset[_]): (Double, Double) = { + val model = new LDA().setK(k).setOptimizer("online").setMaxIter(1).setSeed(1).fit(dataset) + (model.logLikelihood(dataset), model.logPerplexity(dataset)) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val (ll, lp) = trainAndLogLikelihoodAndPerplexity(newDataset) + val (llD, lpD) = trainAndLogLikelihoodAndPerplexity(newDatasetD) + val (llF, lpF) = trainAndLogLikelihoodAndPerplexity(newDatasetF) + // TODO: need to compare the results once we fix the seed issue for LDA (SPARK-22210) + assert(llD <= 0.0 && llD != Double.NegativeInfinity) + assert(llF <= 0.0 && llF != Double.NegativeInfinity) + assert(lpD >= 0.0 && lpD != Double.NegativeInfinity) + assert(lpF >= 0.0 && lpF != Double.NegativeInfinity) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala index 65328df17baff..55b460f1a4524 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -22,14 +22,16 @@ import scala.collection.mutable import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types._ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var data: Dataset[_] = _ final val r1 = 1.0 final val n1 = 10 @@ -48,10 +50,9 @@ class PowerIterationClusteringSuite extends SparkFunSuite assert(pic.getK === 2) assert(pic.getMaxIter === 20) assert(pic.getInitMode === "random") - assert(pic.getPredictionCol === "prediction") - assert(pic.getIdCol === "id") - assert(pic.getNeighborsCol === "neighbors") - assert(pic.getSimilaritiesCol === "similarities") + assert(pic.getSrcCol === "src") + assert(pic.getDstCol === "dst") + assert(!pic.isDefined(pic.weightCol)) } test("parameter validation") { @@ -62,125 +63,110 @@ class PowerIterationClusteringSuite extends SparkFunSuite new PowerIterationClustering().setInitMode("no_such_a_mode") } intercept[IllegalArgumentException] { - new PowerIterationClustering().setIdCol("") - } - intercept[IllegalArgumentException] { - new PowerIterationClustering().setNeighborsCol("") + new PowerIterationClustering().setSrcCol("") } intercept[IllegalArgumentException] { - new PowerIterationClustering().setSimilaritiesCol("") + new PowerIterationClustering().setDstCol("") } } test("power iteration clustering") { val n = n1 + n2 - val model = new PowerIterationClustering() + val assignments = new PowerIterationClustering() .setK(2) .setMaxIter(40) - val result = model.transform(data) + .setWeightCol("weight") + .assignClusters(data) + .select("id", "cluster") + .as[(Long, Int)] + .collect() val predictions = Array.fill(2)(mutable.Set.empty[Long]) - result.select("id", "prediction").collect().foreach { - case Row(id: Long, cluster: Integer) => predictions(cluster) += id + assignments.foreach { + case (id, cluster) => predictions(cluster) += id } - assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + assert(predictions.toSet === Set((0 until n1).toSet, (n1 until n).toSet)) - val result2 = new PowerIterationClustering() + val assignments2 = new PowerIterationClustering() .setK(2) .setMaxIter(10) .setInitMode("degree") - .transform(data) + .setWeightCol("weight") + .assignClusters(data) + .select("id", "cluster") + .as[(Long, Int)] + .collect() + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) - result2.select("id", "prediction").collect().foreach { - case Row(id: Long, cluster: Integer) => predictions2(cluster) += id + assignments2.foreach { + case (id, cluster) => predictions2(cluster) += id } - assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + assert(predictions2.toSet === Set((0 until n1).toSet, (n1 until n).toSet)) } test("supported input types") { - val model = new PowerIterationClustering() + val pic = new PowerIterationClustering() .setK(2) .setMaxIter(1) + .setWeightCol("weight") - def runTest(idType: DataType, neighborType: DataType, similarityType: DataType): Unit = { + def runTest(srcType: DataType, dstType: DataType, weightType: DataType): Unit = { val typedData = data.select( - col("id").cast(idType).alias("id"), - col("neighbors").cast(ArrayType(neighborType, containsNull = false)).alias("neighbors"), - col("similarities").cast(ArrayType(similarityType, containsNull = false)) - .alias("similarities") + col("src").cast(srcType).alias("src"), + col("dst").cast(dstType).alias("dst"), + col("weight").cast(weightType).alias("weight") ) - model.transform(typedData).collect() - } - - for (idType <- Seq(IntegerType, LongType)) { - runTest(idType, LongType, DoubleType) - } - for (neighborType <- Seq(IntegerType, LongType)) { - runTest(LongType, neighborType, DoubleType) - } - for (similarityType <- Seq(FloatType, DoubleType)) { - runTest(LongType, LongType, similarityType) + pic.assignClusters(typedData).collect() } - } - test("invalid input: wrong types") { - val model = new PowerIterationClustering() - .setK(2) - .setMaxIter(1) - intercept[IllegalArgumentException] { - val typedData = data.select( - col("id").cast(DoubleType).alias("id"), - col("neighbors"), - col("similarities") - ) - model.transform(typedData) + for (srcType <- Seq(IntegerType, LongType)) { + runTest(srcType, LongType, DoubleType) } - intercept[IllegalArgumentException] { - val typedData = data.select( - col("id"), - col("neighbors").cast(ArrayType(DoubleType, containsNull = false)).alias("neighbors"), - col("similarities") - ) - model.transform(typedData) + for (dstType <- Seq(IntegerType, LongType)) { + runTest(LongType, dstType, DoubleType) } - intercept[IllegalArgumentException] { - val typedData = data.select( - col("id"), - col("neighbors"), - col("neighbors").alias("similarities") - ) - model.transform(typedData) + for (weightType <- Seq(FloatType, DoubleType)) { + runTest(LongType, LongType, weightType) } } test("invalid input: negative similarity") { - val model = new PowerIterationClustering() + val pic = new PowerIterationClustering() .setMaxIter(1) + .setWeightCol("weight") val badData = spark.createDataFrame(Seq( - (0, Array(1), Array(-1.0)), - (1, Array(0), Array(-1.0)) - )).toDF("id", "neighbors", "similarities") + (0, 1, -1.0), + (1, 0, -1.0) + )).toDF("src", "dst", "weight") val msg = intercept[SparkException] { - model.transform(badData) + pic.assignClusters(badData) }.getCause.getMessage assert(msg.contains("Similarity must be nonnegative")) } - test("invalid input: mismatched lengths for neighbor and similarity arrays") { - val model = new PowerIterationClustering() - .setMaxIter(1) - val badData = spark.createDataFrame(Seq( - (0, Array(1), Array(0.5)), - (1, Array(0, 2), Array(0.5)), - (2, Array(1), Array(0.5)) - )).toDF("id", "neighbors", "similarities") - val msg = intercept[SparkException] { - model.transform(badData) - }.getCause.getMessage - assert(msg.contains("The length of the neighbor ID list must be equal to the the length of " + - "the neighbor similarity list.")) - assert(msg.contains(s"Row for ID ${model.getIdCol}=1")) + test("test default weight") { + val dataWithoutWeight = data.sample(0.5, 1L).select('src, 'dst) + + val assignments = new PowerIterationClustering() + .setK(2) + .setMaxIter(40) + .assignClusters(dataWithoutWeight) + val localAssignments = assignments + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + + val dataWithWeightOne = dataWithoutWeight.withColumn("weight", lit(1.0)) + + val assignments2 = new PowerIterationClustering() + .setK(2) + .setMaxIter(40) + .assignClusters(dataWithWeightOne) + val localAssignments2 = assignments2 + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + + assert(localAssignments === localAssignments2) } test("read/write") { @@ -188,10 +174,9 @@ class PowerIterationClusteringSuite extends SparkFunSuite .setK(4) .setMaxIter(100) .setInitMode("degree") - .setIdCol("test_id") - .setNeighborsCol("myNeighborsCol") - .setSimilaritiesCol("mySimilaritiesCol") - .setPredictionCol("test_prediction") + .setSrcCol("src1") + .setDstCol("dst1") + .setWeightCol("weight") testDefaultReadWrite(t) } } @@ -222,17 +207,13 @@ object PowerIterationClusteringSuite { val n = n1 + n2 val points = genCircle(r1, n1) ++ genCircle(r2, n2) - val rows = for (i <- 1 until n) yield { - val neighbors = for (j <- 0 until i) yield { - j.toLong + val rows = (for (i <- 1 until n) yield { + for (j <- 0 until i) yield { + (i.toLong, j.toLong, sim(points(i), points(j))) } - val similarities = for (j <- 0 until i) yield { - sim(points(i), points(j)) - } - (i.toLong, neighbors.toArray, similarities.toArray) - } + }).flatMap(_.iterator) - spark.createDataFrame(rows).toDF("id", "neighbors", "similarities") + spark.createDataFrame(rows).toDF("src", "dst", "weight") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 21259a50916d2..20972d1f403b9 100755 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -65,6 +65,57 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { testStopWordsRemover(remover, dataSet) } + test("StopWordsRemover with localed input (case insensitive)") { + val stopWords = Array("milk", "cookie") + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords) + .setCaseSensitive(false) + .setLocale("tr") // Turkish alphabet: has no Q, W, X but has dotted and dotless 'I's. + val dataSet = Seq( + // scalastyle:off + (Seq("mİlk", "and", "nuts"), Seq("and", "nuts")), + // scalastyle:on + (Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")), + (Seq(null), Seq(null)), + (Seq(), Seq()) + ).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } + + test("StopWordsRemover with localed input (case sensitive)") { + val stopWords = Array("milk", "cookie") + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords) + .setCaseSensitive(true) + .setLocale("tr") // Turkish alphabet: has no Q, W, X but has dotted and dotless 'I's. + val dataSet = Seq( + // scalastyle:off + (Seq("mİlk", "and", "nuts"), Seq("mİlk", "and", "nuts")), + // scalastyle:on + (Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")), + (Seq(null), Seq(null)), + (Seq(), Seq()) + ).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } + + test("StopWordsRemover with invalid locale") { + intercept[IllegalArgumentException] { + val stopWords = Array("test", "a", "an", "the") + new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords) + .setLocale("rt") // invalid locale + } + } + test("StopWordsRemover case sensitive") { val remover = new StopWordsRemover() .setInputCol("raw") diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala new file mode 100644 index 0000000000000..2252151af306b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.fpm + +import org.apache.spark.ml.util.MLTest +import org.apache.spark.sql.DataFrame + +class PrefixSpanSuite extends MLTest { + + import testImplicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + } + + test("PrefixSpan projections with multiple partial starts") { + val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence") + val result = new PrefixSpan() + .setMinSupport(1.0) + .setMaxPatternLength(2) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(smallDataset) + .as[(Seq[Seq[Int]], Long)].collect() + val expected = Array( + (Seq(Seq(1)), 1L), + (Seq(Seq(1, 2)), 1L), + (Seq(Seq(1), Seq(1)), 1L), + (Seq(Seq(1), Seq(2)), 1L), + (Seq(Seq(1), Seq(3)), 1L), + (Seq(Seq(1, 3)), 1L), + (Seq(Seq(2)), 1L), + (Seq(Seq(2, 3)), 1L), + (Seq(Seq(2), Seq(1)), 1L), + (Seq(Seq(2), Seq(2)), 1L), + (Seq(Seq(2), Seq(3)), 1L), + (Seq(Seq(3)), 1L)) + compareResults[Int](expected, result) + } + + /* + To verify expected results for `smallTestData`, create file "prefixSpanSeqs2" with content + (format = (transactionID, idxInTransaction, numItemsinItemset, itemset)): + 1 1 2 1 2 + 1 2 1 3 + 2 1 1 1 + 2 2 2 3 2 + 2 3 2 1 2 + 3 1 2 1 2 + 3 2 1 5 + 4 1 1 6 + In R, run: + library("arulesSequences") + prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE")) + freqItemSeq = cspade(prefixSpanSeqs, + parameter = 0.5, maxlen = 5 )) + resSeq = as(freqItemSeq, "data.frame") + resSeq + + sequence support + 1 <{1}> 0.75 + 2 <{2}> 0.75 + 3 <{3}> 0.50 + 4 <{1},{3}> 0.50 + 5 <{1,2}> 0.75 + */ + val smallTestData = Seq( + Seq(Seq(1, 2), Seq(3)), + Seq(Seq(1), Seq(3, 2), Seq(1, 2)), + Seq(Seq(1, 2), Seq(5)), + Seq(Seq(6))) + + val smallTestDataExpectedResult = Array( + (Seq(Seq(1)), 3L), + (Seq(Seq(2)), 3L), + (Seq(Seq(3)), 2L), + (Seq(Seq(1), Seq(3)), 2L), + (Seq(Seq(1, 2)), 3L) + ) + + test("PrefixSpan Integer type, variable-size itemsets") { + val df = smallTestData.toDF("sequence") + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) + .as[(Seq[Seq[Int]], Long)].collect() + + compareResults[Int](smallTestDataExpectedResult, result) + } + + test("PrefixSpan input row with nulls") { + val df = (smallTestData :+ null).toDF("sequence") + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) + .as[(Seq[Seq[Int]], Long)].collect() + + compareResults[Int](smallTestDataExpectedResult, result) + } + + test("PrefixSpan String type, variable-size itemsets") { + val intToString = (1 to 6).zip(Seq("a", "b", "c", "d", "e", "f")).toMap + val df = smallTestData + .map(seq => seq.map(itemSet => itemSet.map(intToString))) + .toDF("sequence") + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) + .as[(Seq[Seq[String]], Long)].collect() + + val expected = smallTestDataExpectedResult.map { case (seq, freq) => + (seq.map(itemSet => itemSet.map(intToString)), freq) + } + compareResults[String](expected, result) + } + + private def compareResults[Item]( + expectedValue: Array[(Seq[Seq[Item]], Long)], + actualValue: Array[(Seq[Seq[Item]], Long)]): Unit = { + val expectedSet = expectedValue.map { x => + (x._1.map(_.toSet), x._2) + }.toSet + val actualSet = actualValue.map { x => + (x._1.map(_.toSet), x._2) + }.toSet + assert(expectedSet === actualSet) + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index fad11d078250f..b145c7a3dc952 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -20,13 +20,15 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.lit import org.apache.spark.util.Utils /** @@ -201,9 +203,81 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { assert(mostImportantFeature !== mostIF) } + test("model evaluateEachIteration") { + for (lossType <- GBTRegressor.supportedLossTypes) { + val gbt = new GBTRegressor() + .setSeed(1L) + .setMaxDepth(2) + .setMaxIter(3) + .setLossType(lossType) + val model3 = gbt.fit(trainData.toDF) + val model1 = new GBTRegressionModel("gbt-reg-model-test1", + model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures) + val model2 = new GBTRegressionModel("gbt-reg-model-test2", + model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures) + + for (evalLossType <- GBTRegressor.supportedLossTypes) { + val evalArr = model3.evaluateEachIteration(validationData.toDF, evalLossType) + val lossErr1 = GradientBoostedTrees.computeError(validationData, + model1.trees, model1.treeWeights, model1.convertToOldLossType(evalLossType)) + val lossErr2 = GradientBoostedTrees.computeError(validationData, + model2.trees, model2.treeWeights, model2.convertToOldLossType(evalLossType)) + val lossErr3 = GradientBoostedTrees.computeError(validationData, + model3.trees, model3.treeWeights, model3.convertToOldLossType(evalLossType)) + + assert(evalArr(0) ~== lossErr1 relTol 1E-3) + assert(evalArr(1) ~== lossErr2 relTol 1E-3) + assert(evalArr(2) ~== lossErr3 relTol 1E-3) + } + } + } + test("runWithValidation stops early and performs better on a validation dataset") { + val validationIndicatorCol = "validationIndicator" + val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false)) + val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true)) - ///////////////////////////////////////////////////////////////////////////// + val numIter = 20 + for (lossType <- GBTRegressor.supportedLossTypes) { + val gbt = new GBTRegressor() + .setSeed(123) + .setMaxDepth(2) + .setLossType(lossType) + .setMaxIter(numIter) + val modelWithoutValidation = gbt.fit(trainDF) + + gbt.setValidationIndicatorCol(validationIndicatorCol) + val modelWithValidation = gbt.fit(trainDF.union(validationDF)) + + assert(modelWithoutValidation.numTrees === numIter) + // early stop + assert(modelWithValidation.numTrees < numIter) + + val errorWithoutValidation = GradientBoostedTrees.computeError(validationData, + modelWithoutValidation.trees, modelWithoutValidation.treeWeights, + modelWithoutValidation.getOldLossType) + val errorWithValidation = GradientBoostedTrees.computeError(validationData, + modelWithValidation.trees, modelWithValidation.treeWeights, + modelWithValidation.getOldLossType) + + assert(errorWithValidation < errorWithoutValidation) + + val evaluationArray = GradientBoostedTrees + .evaluateEachIteration(validationData, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, + OldAlgo.Regression) + assert(evaluationArray.length === numIter) + assert(evaluationArray(modelWithValidation.numTrees) > + evaluationArray(modelWithValidation.numTrees - 1)) + var i = 1 + while (i < modelWithValidation.numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index d5bcbb221783e..997c50157dcda 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -493,11 +493,20 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest } [1] -0.0457441 -0.6833928 [1] 1.8121235 -0.1747493 -0.5815417 + + R code for deivance calculation: + data = cbind(y=c(0,1,0,0,0,1), x1=c(18, 12, 15, 13, 15, 16), x2=c(1,0,0,2,1,1)) + summary(glm(y~x1+x2, family=poisson, data=data.frame(data)))$deviance + [1] 3.70055 + summary(glm(y~x1+x2-1, family=poisson, data=data.frame(data)))$deviance + [1] 3.809296 */ val expected = Seq( Vectors.dense(0.0, -0.0457441, -0.6833928), Vectors.dense(1.8121235, -0.1747493, -0.5815417)) + val residualDeviancesR = Array(3.809296, 3.70055) + import GeneralizedLinearRegression._ var idx = 0 @@ -510,6 +519,7 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + s"$link link and fitIntercept = $fitIntercept (with zero values).") + assert(model.summary.deviance ~== residualDeviancesR(idx) absTol 1E-3) idx += 1 } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 15dade2627090..e6ee7220d2279 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -25,17 +25,17 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator} import org.apache.spark.ml.feature.HashingTF -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} -import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType class CrossValidatorSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -66,6 +66,13 @@ class CrossValidatorSuite assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) assert(cvModel.avgMetrics.length === lrParamMaps.length) + + val result = cvModel.transform(dataset).select("prediction").as[Double].collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), cvModel, "prediction") { + rows => + val result2 = rows.map(_.getDouble(0)) + assert(result === result2) + } } test("cross validation with linear regression") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 9024342d9c831..cd76acf9c67bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -24,17 +24,17 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} -import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType class TrainValidationSplitSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -64,6 +64,13 @@ class TrainValidationSplitSuite assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) assert(tvsModel.validationMetrics.length === lrParamMaps.length) + + val result = tvsModel.transform(dataset).select("prediction").as[Double].collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), tvsModel, "prediction") { + rows => + val result2 = rows.map(_.getDouble(0)) + assert(result === result2) + } } test("train validation with linear regression") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 4da95e74434ee..4d9e664850c12 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -19,9 +19,10 @@ package org.apache.spark.ml.util import java.io.{File, IOException} +import org.json4s.JNothing import org.scalatest.Suite -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -129,6 +130,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => class MyParams(override val uid: String) extends Params with MLWritable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") + final val shouldNotSetIfSetintParamWithDefault: IntParam = + new IntParam(this, "shouldNotSetIfSetintParamWithDefault", "doc") final val intParam: IntParam = new IntParam(this, "intParam", "doc") final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc") final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc") @@ -150,6 +153,13 @@ class MyParams(override val uid: String) extends Params with MLWritable { set(doubleArrayParam -> Array(8.0, 9.0)) set(stringArrayParam -> Array("10", "11")) + def checkExclusiveParams(): Unit = { + if (isSet(shouldNotSetIfSetintParamWithDefault) && isSet(intParamWithDefault)) { + throw new SparkException("intParamWithDefault and shouldNotSetIfSetintParamWithDefault " + + "shouldn't be set at the same time") + } + } + override def copy(extra: ParamMap): Params = defaultCopy(extra) override def write: MLWriter = new DefaultParamsWriter(this) @@ -169,4 +179,65 @@ class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext val myParams = new MyParams("my_params") testDefaultReadWrite(myParams) } + + test("default param shouldn't become user-supplied param after persistence") { + val myParams = new MyParams("my_params") + myParams.set(myParams.shouldNotSetIfSetintParamWithDefault, 1) + myParams.checkExclusiveParams() + val loadedMyParams = testDefaultReadWrite(myParams) + loadedMyParams.checkExclusiveParams() + assert(loadedMyParams.getDefault(loadedMyParams.intParamWithDefault) == + myParams.getDefault(myParams.intParamWithDefault)) + + loadedMyParams.set(myParams.intParamWithDefault, 1) + intercept[SparkException] { + loadedMyParams.checkExclusiveParams() + } + } + + test("User-supplied value for default param should be kept after persistence") { + val myParams = new MyParams("my_params") + myParams.set(myParams.intParamWithDefault, 100) + val loadedMyParams = testDefaultReadWrite(myParams) + assert(loadedMyParams.get(myParams.intParamWithDefault).get == 100) + } + + test("Read metadata without default field prior to 2.4") { + // default params are saved in `paramMap` field in metadata file prior to Spark 2.4. + val metadata = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"2.3.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata = DefaultParamsReader.parseMetadata(metadata) + val myParams = new MyParams("my_params") + assert(!myParams.isSet(myParams.intParamWithDefault)) + parsedMetadata.getAndSetParams(myParams) + + // The behavior prior to Spark 2.4, default params are set in loaded ML instance. + assert(myParams.isSet(myParams.intParamWithDefault)) + } + + test("Should raise error when read metadata without default field after Spark 2.4") { + val myParams = new MyParams("my_params") + + val metadata1 = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"2.4.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata1 = DefaultParamsReader.parseMetadata(metadata1) + val err1 = intercept[IllegalArgumentException] { + parsedMetadata1.getAndSetParams(myParams) + } + assert(err1.getMessage().contains("Cannot recognize JSON metadata")) + + val metadata2 = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"3.0.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata2 = DefaultParamsReader.parseMetadata(metadata2) + val err2 = intercept[IllegalArgumentException] { + parsedMetadata2.getAndSetParams(myParams) + } + assert(err2.getMessage().contains("Cannot recognize JSON metadata")) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index c328d81b4bc3a..5e72b4d864c1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol} import org.apache.spark.ml.recommendation.{ALS, ALSModel} @@ -247,4 +247,25 @@ object MLTestingUtils extends SparkFunSuite { } models.sliding(2).foreach { case Seq(m1, m2) => modelEquals(m1, m2)} } + + /** + * Helper function for testing different input types for "features" column. Given a DataFrame, + * generate three output DataFrames: one having vector "features" column with float precision, + * one having double array "features" column with float precision, and one having float array + * "features" column. + */ + def generateArrayFeatureDataset(dataset: Dataset[_], + featuresColName: String = "features"): (Dataset[_], Dataset[_], Dataset[_]) = { + val toFloatVectorUDF = udf { (features: Vector) => + Vectors.dense(features.toArray.map(_.toFloat.toDouble))} + val toDoubleArrayUDF = udf { (features: Vector) => features.toArray} + val toFloatArrayUDF = udf { (features: Vector) => features.toArray.map(_.toFloat)} + val newDataset = dataset.withColumn(featuresColName, toFloatVectorUDF(col(featuresColName))) + val newDatasetD = newDataset.withColumn(featuresColName, toDoubleArrayUDF(col(featuresColName))) + val newDatasetF = newDataset.withColumn(featuresColName, toFloatArrayUDF(col(featuresColName))) + assert(newDataset.schema(featuresColName).dataType.equals(new VectorUDT)) + assert(newDatasetD.schema(featuresColName).dataType.equals(new ArrayType(DoubleType, false))) + assert(newDatasetF.schema(featuresColName).dataType.equals(new ArrayType(FloatType, false))) + (newDataset, newDatasetD, newDatasetF) + } } diff --git a/pom.xml b/pom.xml index 0a711f287a53f..039292337eaa0 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,7 @@ external/kafka-0-10 external/kafka-0-10-assembly external/kafka-0-10-sql + external/avro @@ -129,8 +130,8 @@ 1.2.1 10.12.1.1 - 1.8.2 - 1.4.3 + 1.10.0 + 1.4.4 nohive 1.6.0 9.3.20.v20170531 @@ -155,7 +156,7 @@ 3.4.1 3.2.2 - 2.11.8 + 2.11.12 2.11 1.9.13 2.6.7 @@ -313,13 +314,13 @@ chill-java ${chill.version}
      - org.apache.xbean - xbean-asm5-shaded - 4.4 + xbean-asm6-shaded + 4.8 jline jline - 2.12.1 + 2.14.3 org.scalatest @@ -760,6 +761,12 @@ 1.10.19 test + + org.jmock + jmock-junit4 + test + 2.8.4 + org.scalacheck scalacheck_${scala.binary.version} @@ -1778,6 +1785,12 @@ parquet-hadoop ${parquet.version} ${parquet.deps.scope} + + + commons-pool + commons-pool + + org.apache.parquet @@ -2110,7 +2123,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.20.1 + 2.22.0 @@ -2594,6 +2607,28 @@ + + com.github.spotbugs + spotbugs-maven-plugin + 3.1.3 + + ${basedir}/target/scala-${scala.binary.version}/classes + ${basedir}/target/scala-${scala.binary.version}/test-classes + Max + Low + true + FindPuzzlers + false + + + + + check + + compile + + + @@ -2671,6 +2706,15 @@ + + hadoop-3.1 + + 3.1.0 + 2.12.0 + 3.4.9 + + + yarn @@ -2690,6 +2734,7 @@ kubernetes resource-managers/kubernetes/core + resource-managers/kubernetes/integration-tests @@ -2733,7 +2778,7 @@ scala-2.12 - 2.12.4 + 2.12.6 2.12 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a87fa68422c34..8f96bb0f33849 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,19 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-23528] Add numIter to ClusteringSummary + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.ClusteringSummary.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansSummary.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.BisectingKMeansSummary.this"), + + // [SPARK-6237][NETWORK] Network-layer changes to allow stream upload + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.receive"), + + // [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.this"), + // [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"), @@ -62,12 +75,29 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"), + // [SPARK-23455][ML] Default Params in ML should be saved separately in metadata + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.paramMap"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$paramMap_="), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="), + // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"), ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this"), + + // [SPARK-7132][ML] Add fit with validation set to spark.ml GBT + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol") ) // Exclude rules for 2.3.x diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7469f11df0294..247b6fee394bc 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -27,6 +27,7 @@ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion +import com.etsy.sbt.checkstyle.CheckstylePlugin.autoImport._ import com.simplytyped.Antlr4Plugin._ import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import com.typesafe.tools.mima.plugin.MimaKeys @@ -39,8 +40,8 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010) = Seq( - "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" + val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010, avro) = Seq( + "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10", "avro" ).map(ProjectRef(buildLocation, _)) val streamingProjects@Seq(streaming, streamingKafka010) = @@ -56,11 +57,11 @@ object BuildCommons { val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn, streamingFlumeSink, streamingFlume, streamingKafka, sparkGangliaLgpl, streamingKinesisAsl, - dockerIntegrationTests, hadoopCloud) = + dockerIntegrationTests, hadoopCloud, kubernetesIntegrationTests) = Seq("kubernetes", "mesos", "yarn", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", - "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) + "docker-integration-tests", "hadoop-cloud", "kubernetes-integration-tests").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-0-8-assembly", "streaming-kafka-0-10-assembly", "streaming-kinesis-asl-assembly") @@ -317,7 +318,7 @@ object SparkBuild extends PomBuild { /* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ copyJarsProjects ++ Seq(spark, tools)) .foreach(enable(sharedSettings ++ DependencyOverrides.settings ++ - ExcludedDependencies.settings)) + ExcludedDependencies.settings ++ Checkstyle.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) @@ -325,7 +326,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags, sqlKafka010, kvstore + unsafe, tags, sqlKafka010, kvstore, avro ).contains(x) } @@ -463,7 +464,8 @@ object DockerIntegrationTests { */ object DependencyOverrides { lazy val settings = Seq( - dependencyOverrides += "com.google.guava" % "guava" % "14.0.1") + dependencyOverrides += "com.google.guava" % "guava" % "14.0.1", + dependencyOverrides += "jline" % "jline" % "2.14.3") } /** @@ -686,9 +688,11 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, + yarn, tags, streamingKafka010, sqlKafka010, avro), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, + yarn, tags, streamingKafka010, sqlKafka010, avro), unidocAllClasspaths in (ScalaUnidoc, unidoc) := { ignoreClasspaths((unidocAllClasspaths in (ScalaUnidoc, unidoc)).value) @@ -728,7 +732,8 @@ object Unidoc { scalacOptions in (ScalaUnidoc, unidoc) ++= Seq( "-groups", // Group similar methods together based on the @group annotation. - "-skip-packages", "org.apache.hadoop" + "-skip-packages", "org.apache.hadoop", + "-sourcepath", (baseDirectory in ThisBuild).value.getAbsolutePath ) ++ ( // Add links to sources when generating Scaladoc for a non-snapshot release if (!isSnapshot.value) { @@ -740,6 +745,17 @@ object Unidoc { ) } +object Checkstyle { + lazy val settings = Seq( + checkstyleSeverityLevel := Some(CheckstyleSeverityLevel.Error), + javaSource in (Compile, checkstyle) := baseDirectory.value / "src/main/java", + javaSource in (Test, checkstyle) := baseDirectory.value / "src/test/java", + checkstyleConfigLocation := CheckstyleConfigLocation.File("dev/checkstyle.xml"), + checkstyleOutputFile := baseDirectory.value / "target/checkstyle-output.xml", + checkstyleOutputFile in Test := baseDirectory.value / "target/checkstyle-output.xml" + ) +} + object CopyDependencies { val copyDeps = TaskKey[Unit]("copyDeps", "Copies needed dependencies to the build directory.") diff --git a/project/build.properties b/project/build.properties index b19518fd7aa1c..d03985d980ec8 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.16 +sbt.version=0.13.17 diff --git a/project/plugins.sbt b/project/plugins.sbt index 96bdb9067ae59..ffbd417b0f145 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,3 +1,11 @@ +addSbtPlugin("com.etsy" % "sbt-checkstyle-plugin" % "3.1.1") + +// sbt-checkstyle-plugin uses an old version of checkstyle. Match it to Maven's. +libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "8.2" + +// checkstyle uses guava 23.0. +libraryDependencies += "com.google.guava" % "guava" % "23.0" + // need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-assembly" % "1.14.5" addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") diff --git a/python/README.md b/python/README.md index 2e0112da58b94..c020d84b01ffd 100644 --- a/python/README.md +++ b/python/README.md @@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c ## Python Requirements -At its core PySpark depends on Py4J (currently version 0.10.6), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow). +At its core PySpark depends on Py4J (currently version 0.10.7), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow). diff --git a/python/docs/Makefile b/python/docs/Makefile index 09898f29950ed..1ed1f33af2326 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -1,18 +1,43 @@ # Makefile for Sphinx documentation # +ifndef SPHINXBUILD +ifndef SPHINXPYTHON +SPHINXBUILD = sphinx-build +endif +endif + +ifdef SPHINXBUILD +# User-friendly check for sphinx-build. +ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) +$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif +else +# Note that there is an issue with Python version and Sphinx in PySpark documentation generation. +# Please remove this check below when this issue is fixed. See SPARK-24530 for more details. +PYTHON_VERSION_CHECK = $(shell $(SPHINXPYTHON) -c 'import sys; print(sys.version_info < (3, 0, 0))') +ifeq ($(PYTHON_VERSION_CHECK), True) +$(error Note that Python 3 is required to generate PySpark documentation correctly for now. Current Python executable was less than Python 3. See SPARK-24530. To force Sphinx to use a specific Python executable, please set SPHINXPYTHON to point to the Python 3 executable.) +endif +# Check if Sphinx is installed. +ifeq ($(shell $(SPHINXPYTHON) -c 'import sphinx' >/dev/null 2>&1; echo $$?), 1) +$(error Python executable '$(SPHINXPYTHON)' did not have Sphinx installed. Make sure you have Sphinx installed, then set the SPHINXPYTHON environment variable to point to the Python executable having Sphinx installed. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif +# Use 'SPHINXPYTHON -msphinx' instead of 'sphinx-build'. See https://github.com/sphinx-doc/sphinx/pull/3523 for more details. +SPHINXBUILD = $(SPHINXPYTHON) -msphinx +endif + # You can set these variables from the command line. SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build +# You can set SPHINXBUILD to specify Sphinx build executable or SPHINXPYTHON to specify the Python executable used in Sphinx. +# They follow: +# 1. if SPHINXPYTHON is set, use Python. If SPHINXBUILD is set, use sphinx-build. +# 2. If both are set, SPHINXBUILD has a higher priority over SPHINXPYTHON +# 3. By default, SPHINXBUILD is used as 'sphinx-build'. -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.6-src.zip) - -# User-friendly check for sphinx-build -ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) -$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) -endif +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip) # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 diff --git a/python/lib/py4j-0.10.6-src.zip b/python/lib/py4j-0.10.6-src.zip deleted file mode 100644 index 2f8edcc0c0b88..0000000000000 Binary files a/python/lib/py4j-0.10.6-src.zip and /dev/null differ diff --git a/python/lib/py4j-0.10.7-src.zip b/python/lib/py4j-0.10.7-src.zip new file mode 100644 index 0000000000000..128e321078793 Binary files /dev/null and b/python/lib/py4j-0.10.7-src.zip differ diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index ea845b98b3db2..88519d7311fcc 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -272,7 +272,7 @@ def save_memoryview(self, obj): if not PY3: def save_buffer(self, obj): self.save(str(obj)) - dispatch[buffer] = save_buffer + dispatch[buffer] = save_buffer # noqa: F821 'buffer' was removed in Python 3 def save_unsupported(self, obj): raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) @@ -801,10 +801,10 @@ def save_ellipsis(self, obj): def save_not_implemented(self, obj): self.save_reduce(_gen_not_implemented, ()) - if PY3: - dispatch[io.TextIOWrapper] = save_file - else: + try: # Python 2 dispatch[file] = save_file + except NameError: # Python 3 + dispatch[io.TextIOWrapper] = save_file dispatch[type(Ellipsis)] = save_ellipsis dispatch[type(NotImplemented)] = save_not_implemented @@ -819,6 +819,11 @@ def save_logger(self, obj): dispatch[logging.Logger] = save_logger + def save_root_logger(self, obj): + self.save_reduce(logging.getLogger, (), obj=obj) + + dispatch[logging.RootLogger] = save_root_logger + """Special functions for Add-on libraries""" def inject_addons(self): """Plug in system. Register additional pickling functions if modules already loaded""" diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 7c664966ed74e..2cb3117184334 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -211,9 +211,21 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, for path in self._conf.get("spark.submit.pyFiles", "").split(","): if path != "": (dirname, filename) = os.path.split(path) - if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: - self._python_includes.append(filename) - sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) + try: + filepath = os.path.join(SparkFiles.getRootDirectory(), filename) + if not os.path.exists(filepath): + # In case of YARN with shell mode, 'spark.submit.pyFiles' files are + # not added via SparkContext.addFile. Here we check if the file exists, + # try to copy and then add it to the path. See SPARK-21945. + shutil.copyfile(path, filepath) + if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: + self._python_includes.append(filename) + sys.path.insert(1, filepath) + except Exception: + warnings.warn( + "Failed to add file [%s] speficied in 'spark.submit.pyFiles' to " + "Python path:\n %s" % (path, "\n ".join(sys.path)), + RuntimeWarning) # Create a temporary directory inside spark.local.dir: local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf()) @@ -835,6 +847,8 @@ def addFile(self, path, recursive=False): A directory can be given if the recursive option is set to True. Currently directories are only supported for Hadoop-supported filesystems. + .. note:: A path can be added only once. Subsequent additions of the same path are ignored. + >>> from pyspark import SparkFiles >>> path = os.path.join(tempdir, "test.txt") >>> with open(path, "w") as testFile: @@ -855,6 +869,8 @@ def addPyFile(self, path): SparkContext in the future. The C{path} passed can be either a local file, a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, HTTPS or FTP URI. + + .. note:: A path can be added only once. Subsequent additions of the same path are ignored. """ self.addFile(path) (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix @@ -998,8 +1014,8 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. mappedRDD = rdd.mapPartitions(partitionFunc) - port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions) - return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) + sock_info = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions) + return list(_load_from_socket(sock_info, mappedRDD._jrdd_deserializer)) def show_profiles(self): """ Print the profile stats to stdout """ diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 7bed5216eabf3..ebdd665e349c5 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -29,7 +29,7 @@ from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT from pyspark.worker import main as worker_main -from pyspark.serializers import read_int, write_int +from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer def compute_real_exit_code(exit_code): @@ -40,7 +40,7 @@ def compute_real_exit_code(exit_code): return 1 -def worker(sock): +def worker(sock, authenticated): """ Called by a worker process after the fork(). """ @@ -56,6 +56,18 @@ def worker(sock): # otherwise writes also cause a seek that makes us miss data on the read side. infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536) outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536) + + if not authenticated: + client_secret = UTF8Deserializer().loads(infile) + if os.environ["PYTHON_WORKER_FACTORY_SECRET"] == client_secret: + write_with_length("ok".encode("utf-8"), outfile) + outfile.flush() + else: + write_with_length("err".encode("utf-8"), outfile) + outfile.flush() + sock.close() + return 1 + exit_code = 0 try: worker_main(infile, outfile) @@ -153,8 +165,11 @@ def handle_sigterm(*args): write_int(os.getpid(), outfile) outfile.flush() outfile.close() + authenticated = False while True: - code = worker(sock) + code = worker(sock, authenticated) + if code == 0: + authenticated = True if not reuse or code: # wait for closing try: diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3e704fe9bf6ec..fa2d5e8db716a 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -21,16 +21,19 @@ import select import signal import shlex +import shutil import socket import platform +import tempfile +import time from subprocess import Popen, PIPE if sys.version >= '3': xrange = range -from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters from pyspark.find_spark_home import _find_spark_home -from pyspark.serializers import read_int +from pyspark.serializers import read_int, write_with_length, UTF8Deserializer def launch_gateway(conf=None): @@ -41,6 +44,7 @@ def launch_gateway(conf=None): """ if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) + gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"] else: SPARK_HOME = _find_spark_home() # Launch the Py4j gateway using Spark's run command so that we pick up the @@ -59,40 +63,40 @@ def launch_gateway(conf=None): ]) command = command + shlex.split(submit_args) - # Start a socket that will be used by PythonGatewayServer to communicate its port to us - callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - callback_socket.bind(('127.0.0.1', 0)) - callback_socket.listen(1) - callback_host, callback_port = callback_socket.getsockname() - env = dict(os.environ) - env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host - env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port) - - # Launch the Java gateway. - # We open a pipe to stdin so that the Java gateway can die when the pipe is broken - if not on_windows: - # Don't send ctrl-c / SIGINT to the Java gateway: - def preexec_func(): - signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) - else: - # preexec_fn not supported on Windows - proc = Popen(command, stdin=PIPE, env=env) - - gateway_port = None - # We use select() here in order to avoid blocking indefinitely if the subprocess dies - # before connecting - while gateway_port is None and proc.poll() is None: - timeout = 1 # (seconds) - readable, _, _ = select.select([callback_socket], [], [], timeout) - if callback_socket in readable: - gateway_connection = callback_socket.accept()[0] - # Determine which ephemeral port the server started on: - gateway_port = read_int(gateway_connection.makefile(mode="rb")) - gateway_connection.close() - callback_socket.close() - if gateway_port is None: - raise Exception("Java gateway process exited before sending the driver its port number") + # Create a temporary directory where the gateway server should write the connection + # information. + conn_info_dir = tempfile.mkdtemp() + try: + fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir) + os.close(fd) + os.unlink(conn_info_file) + + env = dict(os.environ) + env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file + + # Launch the Java gateway. + # We open a pipe to stdin so that the Java gateway can die when the pipe is broken + if not on_windows: + # Don't send ctrl-c / SIGINT to the Java gateway: + def preexec_func(): + signal.signal(signal.SIGINT, signal.SIG_IGN) + proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) + else: + # preexec_fn not supported on Windows + proc = Popen(command, stdin=PIPE, env=env) + + # Wait for the file to appear, or for the process to exit, whichever happens first. + while not proc.poll() and not os.path.isfile(conn_info_file): + time.sleep(0.1) + + if not os.path.isfile(conn_info_file): + raise Exception("Java gateway process exited before sending its port number") + + with open(conn_info_file, "rb") as info: + gateway_port = read_int(info) + gateway_secret = UTF8Deserializer().loads(info) + finally: + shutil.rmtree(conn_info_dir) # In Windows, ensure the Java child processes do not linger after Python has exited. # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when @@ -111,7 +115,9 @@ def killChild(): atexit.register(killChild) # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True) + gateway = JavaGateway( + gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret, + auto_convert=True)) # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") @@ -126,3 +132,39 @@ def killChild(): java_import(gateway.jvm, "scala.Tuple2") return gateway + + +def do_server_auth(conn, auth_secret): + """ + Performs the authentication protocol defined by the SocketAuthHelper class on the given + file-like object 'conn'. + """ + write_with_length(auth_secret.encode("utf-8"), conn) + conn.flush() + reply = UTF8Deserializer().loads(conn) + if reply != "ok": + conn.close() + raise Exception("Unexpected reply from iterator server.") + + +def ensure_callback_server_started(gw): + """ + Start callback server if not already started. The callback server is needed if the Java + driver process needs to callback into the Python driver process to execute Python code. + """ + + # getattr will fallback to JVM, so we cannot test by hasattr() + if "_callback_server" not in gw.__dict__ or gw._callback_server is None: + gw.callback_server_parameters.eager_load = True + gw.callback_server_parameters.daemonize = True + gw.callback_server_parameters.daemonize_connections = True + gw.callback_server_parameters.port = 0 + gw.start_callback_server(gw.callback_server_parameters) + cbport = gw._callback_server.server_socket.getsockname()[1] + gw._callback_server.port = cbport + # gateway with real port + gw._python_proxy_port = gw._callback_server.port + # get the GatewayServer object in JVM by ID + jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) + # update the port of CallbackClient with real port + jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py index 129d7d68f7cbb..d99a25390db15 100644 --- a/python/pyspark/ml/__init__.py +++ b/python/pyspark/ml/__init__.py @@ -21,5 +21,11 @@ """ from pyspark.ml.base import Estimator, Model, Transformer, UnaryTransformer from pyspark.ml.pipeline import Pipeline, PipelineModel +from pyspark.ml import classification, clustering, evaluation, feature, fpm, \ + image, pipeline, recommendation, regression, stat, tuning, util, linalg, param -__all__ = ["Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel"] +__all__ = [ + "Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel", + "classification", "clustering", "evaluation", "feature", "fpm", "image", + "recommendation", "regression", "stat", "tuning", "util", "linalg", "param", +] diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ec17653a1adf9..d5963f4f7042c 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -239,6 +239,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti True >>> blorModel.intercept == model2.intercept True + >>> model2 + LogisticRegressionModel: uid = ..., numClasses = 2, numFeatures = 2 .. versionadded:: 1.3.0 """ @@ -562,6 +564,9 @@ def evaluate(self, dataset): java_blr_summary = self._call_java("evaluate", dataset) return BinaryLogisticRegressionSummary(java_blr_summary) + def __repr__(self): + return self._call_java("toString") + class LogisticRegressionSummary(JavaWrapper): """ @@ -1131,6 +1136,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestClassificationModel(java_model) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): @@ -1193,6 +1205,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42) + >>> gbt.getFeatureSubsetStrategy() + 'all' >>> model = gbt.fit(td) >>> model.featureImportances SparseVector(1, {0: 1.0}) @@ -1222,6 +1236,12 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol True >>> model.trees [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] + >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)], + ... ["indexed", "features"]) + >>> model.evaluateEachIteration(validation) + [0.25..., 0.23..., 0.21..., 0.19..., 0.18...] + >>> model.numClasses + 2 .. versionadded:: 1.4.0 """ @@ -1240,19 +1260,22 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", - maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0): + maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, + featureSubsetStrategy="all"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \ + featureSubsetStrategy="all") """ super(GBTClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.GBTClassifier", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0, + featureSubsetStrategy="all") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1261,12 +1284,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0): + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, + featureSubsetStrategy="all"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \ + featureSubsetStrategy="all") Sets params for Gradient Boosted Tree Classification. """ kwargs = self._input_kwargs @@ -1289,8 +1314,15 @@ def getLossType(self): """ return self.getOrDefault(self.lossType) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + -class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, +class GBTClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): """ Model fitted by GBTClassifier. @@ -1319,6 +1351,17 @@ def trees(self): """Trees in this ensemble. Warning: These have null parent Estimators.""" return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @since("2.4.0") + def evaluateEachIteration(self, dataset): + """ + Method to compute error or loss for every iteration of gradient boosting. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + return self._call_java("evaluateEachIteration", dataset) + @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index b3d5fb17f6b81..2f0660040dc7c 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -19,14 +19,15 @@ from pyspark import since, keyword_only from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper from pyspark.ml.param.shared import * from pyspark.ml.common import inherit_doc +from pyspark.sql import DataFrame __all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary', 'KMeans', 'KMeansModel', 'GaussianMixture', 'GaussianMixtureModel', 'GaussianMixtureSummary', - 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel'] + 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel', 'PowerIterationClustering'] class ClusteringSummary(JavaWrapper): @@ -348,8 +349,8 @@ def summary(self): @inherit_doc -class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, - JavaMLWritable, JavaMLReadable): +class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol, HasMaxIter, + HasTol, HasSeed, JavaMLWritable, JavaMLReadable): """ K-means clustering with a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). @@ -405,9 +406,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol typeConverter=TypeConverters.toString) initSteps = Param(Params._dummy(), "initSteps", "The number of steps for k-means|| " + "initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt) - distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " + - "Supported options: 'euclidean' and 'cosine'.", - typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, @@ -543,8 +541,8 @@ def summary(self): @inherit_doc -class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed, - JavaMLWritable, JavaMLReadable): +class BisectingKMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol, + HasMaxIter, HasSeed, JavaMLWritable, JavaMLReadable): """ A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" by Steinbach, Karypis, and Kumar, with modification to fit Spark. @@ -584,6 +582,8 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> bkm2 = BisectingKMeans.load(bkm_path) >>> bkm2.getK() 2 + >>> bkm2.getDistanceMeasure() + 'euclidean' >>> model_path = temp_path + "/bkm_model" >>> model.save(model_path) >>> model2 = BisectingKMeansModel.load(model_path) @@ -606,10 +606,10 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=20, - seed=None, k=4, minDivisibleClusterSize=1.0): + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean"): """ __init__(self, featuresCol="features", predictionCol="prediction", maxIter=20, \ - seed=None, k=4, minDivisibleClusterSize=1.0) + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean") """ super(BisectingKMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.BisectingKMeans", @@ -621,10 +621,10 @@ def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=2 @keyword_only @since("2.0.0") def setParams(self, featuresCol="features", predictionCol="prediction", maxIter=20, - seed=None, k=4, minDivisibleClusterSize=1.0): + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean"): """ setParams(self, featuresCol="features", predictionCol="prediction", maxIter=20, \ - seed=None, k=4, minDivisibleClusterSize=1.0) + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean") Sets params for BisectingKMeans. """ kwargs = self._input_kwargs @@ -658,6 +658,20 @@ def getMinDivisibleClusterSize(self): """ return self.getOrDefault(self.minDivisibleClusterSize) + @since("2.4.0") + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + @since("2.4.0") + def getDistanceMeasure(self): + """ + Gets the value of `distanceMeasure` or its default value. + """ + return self.getOrDefault(self.distanceMeasure) + def _create_model(self, java_model): return BisectingKMeansModel(java_model) @@ -836,7 +850,7 @@ class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInter Terminology: - - "term" = "word": an el + - "term" = "word": an element of the vocabulary - "token": instance of a term appearing in a document - "topic": multinomial distribution over terms representing some concept - "document": one piece of text, corresponding to one row in the input data @@ -938,7 +952,7 @@ def __init__(self, featuresCol="features", maxIter=20, seed=None, checkpointInte k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\ subsamplingRate=0.05, optimizeDocConcentration=True,\ docConcentration=None, topicConcentration=None,\ - topicDistributionCol="topicDistribution", keepLastCheckpoint=True): + topicDistributionCol="topicDistribution", keepLastCheckpoint=True) """ super(LDA, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.LDA", self.uid) @@ -967,7 +981,7 @@ def setParams(self, featuresCol="features", maxIter=20, seed=None, checkpointInt k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\ subsamplingRate=0.05, optimizeDocConcentration=True,\ docConcentration=None, topicConcentration=None,\ - topicDistributionCol="topicDistribution", keepLastCheckpoint=True): + topicDistributionCol="topicDistribution", keepLastCheckpoint=True) Sets params for LDA. """ @@ -1156,10 +1170,189 @@ def getKeepLastCheckpoint(self): return self.getOrDefault(self.keepLastCheckpoint) +@inherit_doc +class PowerIterationClustering(HasMaxIter, HasWeightCol, JavaParams, JavaMLReadable, + JavaMLWritable): + """ + .. note:: Experimental + + Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by + Lin and Cohen. From the abstract: + PIC finds a very low-dimensional embedding of a dataset using truncated power + iteration on a normalized pair-wise similarity matrix of the data. + + This class is not yet an Estimator/Transformer, use :py:func:`assignClusters` method + to run the PowerIterationClustering algorithm. + + .. seealso:: `Wikipedia on Spectral clustering \ + `_ + + >>> data = [(1, 0, 0.5), \ + (2, 0, 0.5), (2, 1, 0.7), \ + (3, 0, 0.5), (3, 1, 0.7), (3, 2, 0.9), \ + (4, 0, 0.5), (4, 1, 0.7), (4, 2, 0.9), (4, 3, 1.1), \ + (5, 0, 0.5), (5, 1, 0.7), (5, 2, 0.9), (5, 3, 1.1), (5, 4, 1.3)] + >>> df = spark.createDataFrame(data).toDF("src", "dst", "weight") + >>> pic = PowerIterationClustering(k=2, maxIter=40, weightCol="weight") + >>> assignments = pic.assignClusters(df) + >>> assignments.sort(assignments.id).show(truncate=False) + +---+-------+ + |id |cluster| + +---+-------+ + |0 |1 | + |1 |1 | + |2 |1 | + |3 |1 | + |4 |1 | + |5 |0 | + +---+-------+ + ... + >>> pic_path = temp_path + "/pic" + >>> pic.save(pic_path) + >>> pic2 = PowerIterationClustering.load(pic_path) + >>> pic2.getK() + 2 + >>> pic2.getMaxIter() + 40 + + .. versionadded:: 2.4.0 + """ + + k = Param(Params._dummy(), "k", + "The number of clusters to create. Must be > 1.", + typeConverter=TypeConverters.toInt) + initMode = Param(Params._dummy(), "initMode", + "The initialization algorithm. This can be either " + + "'random' to use a random vector as vertex properties, or 'degree' to use " + + "a normalized sum of similarities with other vertices. Supported options: " + + "'random' and 'degree'.", + typeConverter=TypeConverters.toString) + srcCol = Param(Params._dummy(), "srcCol", + "Name of the input column for source vertex IDs.", + typeConverter=TypeConverters.toString) + dstCol = Param(Params._dummy(), "dstCol", + "Name of the input column for destination vertex IDs.", + typeConverter=TypeConverters.toString) + + @keyword_only + def __init__(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", + weightCol=None): + """ + __init__(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\ + weightCol=None) + """ + super(PowerIterationClustering, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.clustering.PowerIterationClustering", self.uid) + self._setDefault(k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst") + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.4.0") + def setParams(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", + weightCol=None): + """ + setParams(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\ + weightCol=None) + Sets params for PowerIterationClustering. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.4.0") + def setK(self, value): + """ + Sets the value of :py:attr:`k`. + """ + return self._set(k=value) + + @since("2.4.0") + def getK(self): + """ + Gets the value of :py:attr:`k` or its default value. + """ + return self.getOrDefault(self.k) + + @since("2.4.0") + def setInitMode(self, value): + """ + Sets the value of :py:attr:`initMode`. + """ + return self._set(initMode=value) + + @since("2.4.0") + def getInitMode(self): + """ + Gets the value of :py:attr:`initMode` or its default value. + """ + return self.getOrDefault(self.initMode) + + @since("2.4.0") + def setSrcCol(self, value): + """ + Sets the value of :py:attr:`srcCol`. + """ + return self._set(srcCol=value) + + @since("2.4.0") + def getSrcCol(self): + """ + Gets the value of :py:attr:`srcCol` or its default value. + """ + return self.getOrDefault(self.srcCol) + + @since("2.4.0") + def setDstCol(self, value): + """ + Sets the value of :py:attr:`dstCol`. + """ + return self._set(dstCol=value) + + @since("2.4.0") + def getDstCol(self): + """ + Gets the value of :py:attr:`dstCol` or its default value. + """ + return self.getOrDefault(self.dstCol) + + @since("2.4.0") + def assignClusters(self, dataset): + """ + Run the PIC algorithm and returns a cluster assignment for each input vertex. + + :param dataset: + A dataset with columns src, dst, weight representing the affinity matrix, + which is the matrix A in the PIC paper. Suppose the src column value is i, + the dst column value is j, the weight column value is similarity s,,ij,, + which must be nonnegative. This is a symmetric matrix and hence + s,,ij,, = s,,ji,,. For any (i, j) with nonzero similarity, there should be + either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows with i = j are + ignored, because we assume s,,ij,, = 0.0. + + :return: + A dataset that contains columns of vertex id and the corresponding cluster for + the id. The schema of it will be: + - id: Long + - cluster: Int + + .. versionadded:: 2.4.0 + """ + self._transfer_params_to_java() + jdf = self._java_obj.assignClusters(dataset._jdf) + return DataFrame(jdf, dataset.sql_ctx) + + if __name__ == "__main__": import doctest + import numpy import pyspark.ml.clustering from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.ml.clustering.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index cdda30cfab482..ddba7389145e3 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1294,14 +1294,14 @@ class MinHashLSH(JavaEstimator, LSHParams, HasInputCol, HasOutputCol, HasSeed, >>> mh = MinHashLSH(inputCol="features", outputCol="hashes", seed=12345) >>> model = mh.fit(df) >>> model.transform(df).head() - Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([-1638925... + Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([6179668... >>> data2 = [(3, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0]),), ... (4, Vectors.sparse(6, [2, 3, 5], [1.0, 1.0, 1.0]),), ... (5, Vectors.sparse(6, [1, 2, 4], [1.0, 1.0, 1.0]),)] >>> df2 = spark.createDataFrame(data2, ["id", "features"]) >>> key = Vectors.sparse(6, [1, 2], [1.0, 1.0]) >>> model.approxNearestNeighbors(df2, key, 1).collect() - [Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([-163892... + [Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([6179668... >>> model.approxSimilarityJoin(df, df2, 0.6, distCol="JaccardDistance").select( ... col("datasetA.id").alias("idA"), ... col("datasetB.id").alias("idB"), @@ -1309,8 +1309,8 @@ class MinHashLSH(JavaEstimator, LSHParams, HasInputCol, HasOutputCol, HasSeed, +---+---+---------------+ |idA|idB|JaccardDistance| +---+---+---------------+ - | 1| 4| 0.5| | 0| 5| 0.5| + | 1| 4| 0.5| +---+---+---------------+ ... >>> mhPath = temp_path + "/mh" @@ -2582,25 +2582,31 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl typeConverter=TypeConverters.toListString) caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " + "comparison over the stop words", typeConverter=TypeConverters.toBoolean) + locale = Param(Params._dummy(), "locale", "locale of the input. ignored when case sensitive " + + "is true", typeConverter=TypeConverters.toString) @keyword_only - def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False): + def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False, + locale=None): """ - __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false) + __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \ + locale=None) """ super(StopWordsRemover, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", self.uid) self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"), - caseSensitive=False) + caseSensitive=False, locale=self._java_obj.getLocale()) kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.6.0") - def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False): + def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False, + locale=None): """ - setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false) + setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \ + locale=None) Sets params for this StopWordRemover. """ kwargs = self._input_kwargs @@ -2634,6 +2640,20 @@ def getCaseSensitive(self): """ return self.getOrDefault(self.caseSensitive) + @since("2.4.0") + def setLocale(self, value): + """ + Sets the value of :py:attr:`locale`. + """ + return self._set(locale=value) + + @since("2.4.0") + def getLocale(self): + """ + Gets the value of :py:attr:`locale`. + """ + return self.getOrDefault(self.locale) + @staticmethod @since("2.0.0") def loadDefaultStopWords(language): diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index b8dafd49d354d..fd19fd96c4df6 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -16,8 +16,9 @@ # from pyspark import keyword_only, since +from pyspark.sql import DataFrame from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, _jvm from pyspark.ml.param.shared import * __all__ = ["FPGrowth", "FPGrowthModel"] @@ -243,3 +244,104 @@ def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", def _create_model(self, java_model): return FPGrowthModel(java_model) + + +class PrefixSpan(JavaParams): + """ + .. note:: Experimental + + A parallel PrefixSpan algorithm to mine frequent sequential patterns. + The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + Efficiently by Prefix-Projected Pattern Growth + (see here). + This class is not yet an Estimator/Transformer, use :py:func:`findFrequentSequentialPatterns` + method to run the PrefixSpan algorithm. + + @see Sequential Pattern Mining + (Wikipedia) + .. versionadded:: 2.4.0 + + """ + + minSupport = Param(Params._dummy(), "minSupport", "The minimal support level of the " + + "sequential pattern. Sequential pattern that appears more than " + + "(minSupport * size-of-the-dataset) times will be output. Must be >= 0.", + typeConverter=TypeConverters.toFloat) + + maxPatternLength = Param(Params._dummy(), "maxPatternLength", + "The maximal length of the sequential pattern. Must be > 0.", + typeConverter=TypeConverters.toInt) + + maxLocalProjDBSize = Param(Params._dummy(), "maxLocalProjDBSize", + "The maximum number of items (including delimiters used in the " + + "internal storage format) allowed in a projected database before " + + "local processing. If a projected database exceeds this size, " + + "another iteration of distributed prefix growth is run. " + + "Must be > 0.", + typeConverter=TypeConverters.toInt) + + sequenceCol = Param(Params._dummy(), "sequenceCol", "The name of the sequence column in " + + "dataset, rows with nulls in this column are ignored.", + typeConverter=TypeConverters.toString) + + @keyword_only + def __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence"): + """ + __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ + sequenceCol="sequence") + """ + super(PrefixSpan, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.PrefixSpan", self.uid) + self._setDefault(minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence") + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.4.0") + def setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence"): + """ + setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ + sequenceCol="sequence") + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.4.0") + def findFrequentSequentialPatterns(self, dataset): + """ + .. note:: Experimental + Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + + :param dataset: A dataframe containing a sequence column which is + `ArrayType(ArrayType(T))` type, T is the item type for the input dataset. + :return: A `DataFrame` that contains columns of sequence and corresponding frequency. + The schema of it will be: + - `sequence: ArrayType(ArrayType(T))` (T is the item type) + - `freq: Long` + + >>> from pyspark.ml.fpm import PrefixSpan + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(sequence=[[1, 2], [3]]), + ... Row(sequence=[[1], [3, 2], [1, 2]]), + ... Row(sequence=[[1, 2], [5]]), + ... Row(sequence=[[6]])]).toDF() + >>> prefixSpan = PrefixSpan(minSupport=0.5, maxPatternLength=5) + >>> prefixSpan.findFrequentSequentialPatterns(df).sort("sequence").show(truncate=False) + +----------+----+ + |sequence |freq| + +----------+----+ + |[[1]] |3 | + |[[1], [3]]|2 | + |[[1, 2]] |3 | + |[[2]] |3 | + |[[3]] |2 | + +----------+----+ + + .. versionadded:: 2.4.0 + """ + self._transfer_params_to_java() + jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf) + return DataFrame(jdf, dataset.sql_ctx) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 96d702f844839..5f0c57ee3cc67 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -31,6 +31,8 @@ from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string from pyspark.sql import DataFrame, SparkSession +__all__ = ["ImageSchema"] + class _ImageSchema(object): """ diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index 6a611a2b5b59d..2548fd0f50b33 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -1156,6 +1156,11 @@ def sparse(numRows, numCols, colPtrs, rowIndices, values): def _test(): import doctest + try: + # Numpy 1.14+ changed it's string format. + np.set_printoptions(legacy='1.13') + except TypeError: + pass (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: sys.exit(-1) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 6e9e0a34cdfde..e45ba840b412b 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -162,7 +162,9 @@ def get$Name(self): "fitting. If set to true, then all sub-models will be available. Warning: For large " + "models, collecting all sub-models can cause OOMs on the Spark driver.", "False", "TypeConverters.toBoolean"), - ("loss", "the loss function to be optimized.", None, "TypeConverters.toString")] + ("loss", "the loss function to be optimized.", None, "TypeConverters.toString"), + ("distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", + "'euclidean'", "TypeConverters.toString")] code = [] for name, doc, defaultValueStr, typeConverter in shared: diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 08408ee8fbfcc..618f5bf0a8103 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -790,3 +790,27 @@ def getCacheNodeIds(self): """ return self.getOrDefault(self.cacheNodeIds) + +class HasDistanceMeasure(Params): + """ + Mixin for param distanceMeasure: the distance measure. Supported options: 'euclidean' and 'cosine'. + """ + + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", typeConverter=TypeConverters.toString) + + def __init__(self): + super(HasDistanceMeasure, self).__init__() + self._setDefault(distanceMeasure='euclidean') + + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + def getDistanceMeasure(self): + """ + Gets the value of distanceMeasure or its default value. + """ + return self.getOrDefault(self.distanceMeasure) + diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 9a66d87d7f211..83f0edb397271 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -95,6 +95,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction True >>> model.numFeatures 1 + >>> model.write().format("pmml").save(model_path + "_2") .. versionadded:: 1.4.0 """ @@ -161,7 +162,7 @@ def getEpsilon(self): return self.getOrDefault(self.epsilon) -class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): +class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable): """ Model fitted by :class:`LinearRegression`. @@ -602,6 +603,14 @@ class TreeEnsembleParams(DecisionTreeParams): "used for learning each decision tree, in range (0, 1].", typeConverter=TypeConverters.toFloat) + supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] + + featureSubsetStrategy = \ + Param(Params._dummy(), "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].", + typeConverter=TypeConverters.toString) + def __init__(self): super(TreeEnsembleParams, self).__init__() @@ -619,6 +628,22 @@ def getSubsamplingRate(self): """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + + .. note:: Deprecated in 2.4.0 and will be removed in 3.0.0. + """ + return self._set(featureSubsetStrategy=value) + + @since("1.4.0") + def getFeatureSubsetStrategy(self): + """ + Gets the value of featureSubsetStrategy or its default value. + """ + return self.getOrDefault(self.featureSubsetStrategy) + class TreeRegressorParams(Params): """ @@ -654,14 +679,8 @@ class RandomForestParams(TreeEnsembleParams): Private class to track supported random forest parameters. """ - supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).", typeConverter=TypeConverters.toInt) - featureSubsetStrategy = \ - Param(Params._dummy(), "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].", - typeConverter=TypeConverters.toString) def __init__(self): super(RandomForestParams, self).__init__() @@ -680,20 +699,6 @@ def getNumTrees(self): """ return self.getOrDefault(self.numTrees) - @since("1.4.0") - def setFeatureSubsetStrategy(self, value): - """ - Sets the value of :py:attr:`featureSubsetStrategy`. - """ - return self._set(featureSubsetStrategy=value) - - @since("1.4.0") - def getFeatureSubsetStrategy(self): - """ - Gets the value of featureSubsetStrategy or its default value. - """ - return self.getOrDefault(self.featureSubsetStrategy) - class GBTParams(TreeEnsembleParams): """ @@ -981,6 +986,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestRegressionModel(java_model) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): @@ -1029,6 +1041,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42) >>> print(gbt.getImpurity()) variance + >>> print(gbt.getFeatureSubsetStrategy()) + all >>> model = gbt.fit(df) >>> model.featureImportances SparseVector(1, {0: 1.0}) @@ -1056,6 +1070,10 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, True >>> model.trees [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] + >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))], + ... ["label", "features"]) + >>> model.evaluateEachIteration(validation, "squared") + [0.0, 0.0, 0.0, 0.0, 0.0] .. versionadded:: 1.4.0 """ @@ -1075,20 +1093,20 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, - impurity="variance"): + impurity="variance", featureSubsetStrategy="all"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ - impurity="variance") + impurity="variance", featureSubsetStrategy="all") """ super(GBTRegressor, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, - impurity="variance") + impurity="variance", featureSubsetStrategy="all") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1098,13 +1116,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, - impuriy="variance"): + impuriy="variance", featureSubsetStrategy="all"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ - impurity="variance") + impurity="variance", featureSubsetStrategy="all") Sets params for Gradient Boosted Tree Regression. """ kwargs = self._input_kwargs @@ -1127,6 +1145,13 @@ def getLossType(self): """ return self.getOrDefault(self.lossType) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): """ @@ -1156,6 +1181,20 @@ def trees(self): """Trees in this ensemble. Warning: These have null parent Estimators.""" return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @since("2.4.0") + def evaluateEachIteration(self, dataset, loss): + """ + Method to compute error or loss for every iteration of gradient boosting. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + :param loss: + The loss function used to compute error. + Supported options: squared, absolute + """ + return self._call_java("evaluateEachIteration", dataset, loss) + @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index a06ab31a7a56a..370154fc6d62a 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -388,8 +388,14 @@ def summary(self, featuresCol, weightCol=None): if __name__ == "__main__": import doctest + import numpy import pyspark.ml.stat from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.ml.stat.__dict__.copy() # The small batch size here ensures that we see multiple batches, diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2ec0be60e9fa9..bc782138292bf 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -681,6 +681,13 @@ def test_stopwordsremover(self): self.assertEqual(stopWordRemover.getStopWords(), stopwords) transformedDF = stopWordRemover.transform(dataset) self.assertEqual(transformedDF.head().output, []) + # with locale + stopwords = ["BELKİ"] + dataset = self.spark.createDataFrame([Row(input=["belki"])]) + stopWordRemover.setStopWords(stopwords).setLocale("tr") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, []) def test_count_vectorizer_with_binary(self): dataset = self.spark.createDataFrame([ @@ -1355,6 +1362,23 @@ def test_linear_regression(self): except OSError: pass + def test_linear_regression_pmml_basic(self): + # Most of the validation is done in the Scala side, here we just check + # that we output text rather than parquet (e.g. that the format flag + # was respected). + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1) + model = lr.fit(df) + path = tempfile.mkdtemp() + lr_path = path + "/lr-pmml" + model.write().format("pmml").save(lr_path) + pmml_text_list = self.sc.textFile(lr_path).collect() + pmml_text = "\n".join(pmml_text_list) + self.assertIn("Apache Spark", pmml_text) + self.assertIn("PMML", pmml_text) + def test_logistic_regression(self): lr = LogisticRegression(maxIter=1) path = tempfile.mkdtemp() @@ -1595,6 +1619,44 @@ def test_default_read_write(self): self.assertEqual(lr.uid, lr3.uid) self.assertEqual(lr.extractParamMap(), lr3.extractParamMap()) + def test_default_read_write_default_params(self): + lr = LogisticRegression() + self.assertFalse(lr.isSet(lr.getParam("threshold"))) + + lr.setMaxIter(50) + lr.setThreshold(.75) + + # `threshold` is set by user, default param `predictionCol` is not set by user. + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + writer = DefaultParamsWriter(lr) + metadata = json.loads(writer._get_metadata_to_save(lr, self.sc)) + self.assertTrue("defaultParamMap" in metadata) + + reader = DefaultParamsReadable.read() + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + # manually create metadata without `defaultParamMap` section. + del metadata['defaultParamMap'] + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"): + reader.getAndSetParams(lr, loadedMetadata) + + # Prior to 2.4.0, metadata doesn't have `defaultParamMap`. + metadata['sparkVersion'] = '2.3.0' + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + class LDATest(SparkSessionTestCase): @@ -2136,17 +2198,23 @@ class ImageReaderTest2(PySparkTestCase): @classmethod def setUpClass(cls): super(ImageReaderTest2, cls).setUpClass() + cls.hive_available = True # Note that here we enable Hive's support. cls.spark = None try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False except TypeError: cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") - cls.spark = HiveContext._createForTesting(cls.sc) + cls.hive_available = False + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") @classmethod def tearDownClass(cls): @@ -2662,6 +2730,6 @@ def testDefaultFitMultiple(self): if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index a486c6a3fdeb5..e846834761e49 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -30,6 +30,7 @@ from pyspark import SparkContext, since from pyspark.ml.common import inherit_doc from pyspark.sql import SparkSession +from pyspark.util import VersionUtils def _jvm(): @@ -62,7 +63,7 @@ def _randomUID(cls): Generate a unique unicode id for the object. The default implementation concatenates the class name, "_", and 12 random hex chars. """ - return unicode(cls.__name__ + "_" + uuid.uuid4().hex[12:]) + return unicode(cls.__name__ + "_" + uuid.uuid4().hex[-12:]) @inherit_doc @@ -147,6 +148,23 @@ def overwrite(self): return self +@inherit_doc +class GeneralMLWriter(MLWriter): + """ + Utility class that can save ML instances in different formats. + + .. versionadded:: 2.4.0 + """ + + def format(self, source): + """ + Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class + name for export). + """ + self.source = source + return self + + @inherit_doc class JavaMLWriter(MLWriter): """ @@ -191,6 +209,24 @@ def session(self, sparkSession): return self +@inherit_doc +class GeneralJavaMLWriter(JavaMLWriter): + """ + (Private) Specialization of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types + """ + + def __init__(self, instance): + super(GeneralJavaMLWriter, self).__init__(instance) + + def format(self, source): + """ + Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class + name for export). + """ + self._jwrite.format(source) + return self + + @inherit_doc class MLWritable(object): """ @@ -219,6 +255,17 @@ def write(self): return JavaMLWriter(self) +@inherit_doc +class GeneralJavaMLWritable(JavaMLWritable): + """ + (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`. + """ + + def write(self): + """Returns an GeneralMLWriter instance for this ML instance.""" + return GeneralJavaMLWriter(self) + + @inherit_doc class MLReader(BaseReadWrite): """ @@ -396,6 +443,7 @@ def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None): - sparkVersion - uid - paramMap + - defaultParamMap (since 2.4.0) - (optionally, extra metadata) :param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc. :param paramMap: If given, this is saved in the "paramMap" field. @@ -417,15 +465,24 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None): """ uid = instance.uid cls = instance.__module__ + '.' + instance.__class__.__name__ - params = instance.extractParamMap() + + # User-supplied param values + params = instance._paramMap jsonParams = {} if paramMap is not None: jsonParams = paramMap else: for p in params: jsonParams[p.name] = params[p] + + # Default param values + jsonDefaultParams = {} + for p in instance._defaultParamMap: + jsonDefaultParams[p.name] = instance._defaultParamMap[p] + basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)), - "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams} + "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams, + "defaultParamMap": jsonDefaultParams} if extraMetadata is not None: basicMetadata.update(extraMetadata) return json.dumps(basicMetadata, separators=[',', ':']) @@ -523,11 +580,26 @@ def getAndSetParams(instance, metadata): """ Extract Params from metadata, and set them in the instance. """ + # Set user-supplied param values for paramName in metadata['paramMap']: param = instance.getParam(paramName) paramValue = metadata['paramMap'][paramName] instance.set(param, paramValue) + # Set default param values + majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion']) + major = majorAndMinorVersions[0] + minor = majorAndMinorVersions[1] + + # For metadata file prior to Spark 2.4, there is no default section. + if major > 2 or (major == 2 and minor >= 4): + assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \ + "`defaultParamMap` section not found" + + for paramName in metadata['defaultParamMap']: + paramValue = metadata['defaultParamMap'][paramName] + instance._setDefault(**{paramName: paramValue}) + @staticmethod def loadParamsInstance(path, sc): """ diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index bb281981fd56b..e00ed95ef0701 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -258,6 +258,9 @@ def load(cls, sc, path): model.setThreshold(threshold) return model + def __repr__(self): + return self._call_java("toString") + class LogisticRegressionWithSGD(object): """ diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 0cbabab13a896..b09469b9f5c2d 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -1042,7 +1042,13 @@ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, def _test(): import doctest + import numpy import pyspark.mllib.clustering + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.clustering.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 36cb03369b8c0..6c65da58e4e2b 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -532,8 +532,14 @@ def accuracy(self): def _test(): import doctest + import numpy from pyspark.sql import SparkSession import pyspark.mllib.evaluation + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.evaluation.__dict__.copy() spark = SparkSession.builder\ .master("local[4]")\ diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 60d96d8d5ceb8..4afd6666400b0 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -1368,6 +1368,12 @@ def R(self): def _test(): import doctest + import numpy + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: sys.exit(-1) diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index bba88542167ad..7e8b15056cabe 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -1364,9 +1364,15 @@ def toCoordinateMatrix(self): def _test(): import doctest + import numpy from pyspark.sql import SparkSession from pyspark.mllib.linalg import Matrices import pyspark.mllib.linalg.distributed + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.linalg.distributed.__dict__.copy() spark = SparkSession.builder\ .master("local[2]")\ diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 3c75b132ecad2..937bb154c2356 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -303,7 +303,13 @@ def kolmogorovSmirnovTest(data, distName="norm", *params): def _test(): import doctest + import numpy from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = globals().copy() spark = SparkSession.builder\ .master("local[4]")\ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 1037bab7f1088..4c2ce137e331c 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -57,6 +57,7 @@ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.linalg.distributed import RowMatrix from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.fpm import FPGrowth from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs @@ -1762,14 +1763,25 @@ def test_pca(self): self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) +class FPGrowthTest(MLlibTestCase): + + def test_fpgrowth(self): + data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + rdd = self.sc.parallelize(data, 2) + model1 = FPGrowth.train(rdd, 0.6, 2) + # use default data partition number when numPartitions is not specified + model2 = FPGrowth.train(rdd, 0.6) + self.assertEqual(sorted(model1.freqItemsets().collect()), + sorted(model2.freqItemsets().collect())) + if __name__ == "__main__": from pyspark.mllib.tests import * if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) if not _have_scipy: print("NOTE: SciPy tests were skipped as it does not seem to be installed") sc.stop() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4b44f76747264..951851804b1d8 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -39,9 +39,11 @@ else: from itertools import imap as map, ifilter as filter +from pyspark.java_gateway import do_server_auth from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long, AutoBatchedSerializer + PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \ + UTF8Deserializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -51,6 +53,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync +from pyspark.util import fail_on_stopiteration __all__ = ["RDD"] @@ -71,6 +74,7 @@ class PythonEvalType(object): SQL_SCALAR_PANDAS_UDF = 200 SQL_GROUPED_MAP_PANDAS_UDF = 201 SQL_GROUPED_AGG_PANDAS_UDF = 202 + SQL_WINDOW_AGG_PANDAS_UDF = 203 def portable_hash(x): @@ -136,7 +140,8 @@ def _parse_memory(s): return int(float(s[:-1]) * units[s[-1].lower()]) -def _load_from_socket(port, serializer): +def _load_from_socket(sock_info, serializer): + port, auth_secret = sock_info sock = None # Support for both IPv4 and IPv6. # On most of IPv6-ready systems, IPv6 will take precedence. @@ -156,8 +161,12 @@ def _load_from_socket(port, serializer): # The RDD materialization time is unpredicable, if we set a timeout for socket reading # operation, it will very possibly fail. See SPARK-18281. sock.settimeout(None) + + sockfile = sock.makefile("rwb", 65536) + do_server_auth(sockfile, auth_secret) + # The socket will be automatically closed when garbage-collected. - return serializer.load_stream(sock.makefile("rb", 65536)) + return serializer.load_stream(sockfile) def ignore_unicode_prefix(f): @@ -332,7 +341,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(f, iterator) + return map(fail_on_stopiteration(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -347,7 +356,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(f, iterator)) + return chain.from_iterable(map(fail_on_stopiteration(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -410,7 +419,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(f, iterator) + return filter(fail_on_stopiteration(f), iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -791,6 +800,8 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ + f = fail_on_stopiteration(f) + def processPartition(iterator): for x in iterator: f(x) @@ -822,8 +833,8 @@ def collect(self): to be small, as all the data is loaded into the driver's memory. """ with SCCallSiteSync(self.context) as css: - port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) - return list(_load_from_socket(port, self._jrdd_deserializer)) + sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) + return list(_load_from_socket(sock_info, self._jrdd_deserializer)) def reduce(self, f): """ @@ -840,6 +851,8 @@ def reduce(self, f): ... ValueError: Can not reduce() empty RDD """ + f = fail_on_stopiteration(f) + def func(iterator): iterator = iter(iterator) try: @@ -911,6 +924,8 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ + op = fail_on_stopiteration(op) + def func(iterator): acc = zeroValue for obj in iterator: @@ -943,6 +958,9 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ + seqOp = fail_on_stopiteration(seqOp) + combOp = fail_on_stopiteration(combOp) + def func(iterator): acc = zeroValue for obj in iterator: @@ -1352,7 +1370,10 @@ def takeUpToNumLeft(iterator): iterator = iter(iterator) taken = 0 while taken < left: - yield next(iterator) + try: + yield next(iterator) + except StopIteration: + return taken += 1 p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts)) @@ -1636,6 +1657,8 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ + func = fail_on_stopiteration(func) + def reducePartition(iterator): m = {} for k, v in iterator: @@ -2380,8 +2403,8 @@ def toLocalIterator(self): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ with SCCallSiteSync(self.context) as css: - port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) - return _load_from_socket(port, self._jrdd_deserializer) + sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) + return _load_from_socket(sock_info, self._jrdd_deserializer) def _prepare_for_python_RDD(sc, command): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 15753f77bd903..4c16b5fc26f3d 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -33,8 +33,9 @@ [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >>> sc.stop() -PySpark serialize objects in batches; By default, the batch size is chosen based -on the size of objects, also configurable by SparkContext's C{batchSize} parameter: +PySpark serializes objects in batches; by default, the batch size is chosen based +on the size of objects and is also configurable by SparkContext's C{batchSize} +parameter: >>> sc = SparkContext('local', 'test', batchSize=2) >>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) @@ -100,7 +101,7 @@ def load_stream(self, stream): def _load_stream_without_unbatching(self, stream): """ Return an iterator of deserialized batches (iterable) of objects from the input stream. - if the serializer does not operate on batches the default implementation returns an + If the serializer does not operate on batches the default implementation returns an iterator of single element lists. """ return map(lambda x: [x], self.load_stream(stream)) @@ -461,7 +462,7 @@ def dumps(self, obj): return obj -# Hook namedtuple, make it picklable +# Hack namedtuple, make it picklable __cls = {} @@ -525,15 +526,15 @@ def namedtuple(*args, **kwargs): cls = _old_namedtuple(*args, **kwargs) return _hack_namedtuple(cls) - # replace namedtuple with new one + # replace namedtuple with the new one collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple collections.namedtuple.__code__ = namedtuple.__code__ collections.namedtuple.__hijack = 1 - # hack the cls already generated by namedtuple - # those created in other module can be pickled as normal, + # hack the cls already generated by namedtuple. + # Those created in other modules can be pickled as normal, # so only hack those in __main__ module for n, o in sys.modules["__main__"].__dict__.items(): if (type(o) is type and o.__base__ is tuple @@ -627,7 +628,7 @@ def loads(self, obj): elif _type == b'P': return pickle.loads(obj[1:]) else: - raise ValueError("invalid sevialization type: %s" % _type) + raise ValueError("invalid serialization type: %s" % _type) class CompressedSerializer(FramedSerializer): diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index b5fcf7092d93a..472c3cd4452f0 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -38,25 +38,13 @@ SparkContext._ensure_initialized() try: - # Try to access HiveConf, it will raise exception if Hive is not added - conf = SparkConf() - if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive': - SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() - spark = SparkSession.builder\ - .enableHiveSupport()\ - .getOrCreate() - else: - spark = SparkSession.builder.getOrCreate() -except py4j.protocol.Py4JError: - if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': - warnings.warn("Fall back to non-hive support because failing to access HiveConf, " - "please make sure you build spark with hive") - spark = SparkSession.builder.getOrCreate() -except TypeError: - if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': - warnings.warn("Fall back to non-hive support because failing to access HiveConf, " - "please make sure you build spark with hive") - spark = SparkSession.builder.getOrCreate() + spark = SparkSession._create_shell_session() +except Exception: + import sys + import traceback + warnings.warn("Failed to initialize Spark session.") + traceback.print_exc(file=sys.stderr) + sys.exit(1) sc = spark.sparkContext sql = spark.sql diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 02c773302e9da..bd0ac0039ffe1 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -28,6 +28,7 @@ import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ CompressedSerializer, AutoBatchedSerializer +from pyspark.util import fail_on_stopiteration try: @@ -94,9 +95,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = createCombiner - self.mergeValue = mergeValue - self.mergeCombiners = mergeCombiners + self.createCombiner = fail_on_stopiteration(createCombiner) + self.mergeValue = fail_on_stopiteration(mergeValue) + self.mergeCombiners = fail_on_stopiteration(mergeCombiners) class SimpleAggregator(Aggregator): diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index db49040e17b63..f80bf598c2211 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -63,6 +63,14 @@ def _checkType(self, obj, identifier): raise TypeError("expected %s '%s' to be a string (was '%s')" % (identifier, obj, type(obj).__name__)) + @ignore_unicode_prefix + @since(2.4) + def isModifiable(self, key): + """Indicates whether the configuration property with the given key + is modifiable in the current session. + """ + return self._jconf.isModifiable(key) + def _test(): import os diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index e9ec7ba866761..9c094dd9a9033 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -93,6 +93,11 @@ def _ssql_ctx(self): """ return self._jsqlContext + @property + def _conf(self): + """Accessor for the JVM SQL-specific configurations""" + return self.sparkSession._jsparkSession.sessionState().conf() + @classmethod @since(1.6) def getOrCreate(cls, sc): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 16f8e52dead7b..c40aea9bcef0a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -78,6 +78,9 @@ def __init__(self, jdf, sql_ctx): self.is_cached = False self._schema = None # initialized lazily self._lazy_rdd = None + # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice + # by __repr__ and _repr_html_ while eager evaluation opened. + self._support_repr_html = False @property @since(1.3) @@ -352,7 +355,46 @@ def show(self, n=20, truncate=True, vertical=False): print(self._jdf.showString(n, int(truncate), vertical)) def __repr__(self): - return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + if not self._support_repr_html and self.sql_ctx._conf.isReplEagerEvalEnabled(): + vertical = False + return self._jdf.showString( + self.sql_ctx._conf.replEagerEvalMaxNumRows(), + self.sql_ctx._conf.replEagerEvalTruncate(), vertical) + else: + return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + + def _repr_html_(self): + """Returns a dataframe with html code when you enabled eager evaluation + by 'spark.sql.repl.eagerEval.enabled', this only called by REPL you are + using support eager evaluation with HTML. + """ + import cgi + if not self._support_repr_html: + self._support_repr_html = True + if self.sql_ctx._conf.isReplEagerEvalEnabled(): + max_num_rows = max(self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0) + sock_info = self._jdf.getRowsToPython( + max_num_rows, self.sql_ctx._conf.replEagerEvalTruncate()) + rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) + head = rows[0] + row_data = rows[1:] + has_more_data = len(row_data) > max_num_rows + row_data = row_data[:max_num_rows] + + html = "\n" + # generate table head + html += "\n" % "\n" % "
      %s
      ".join(map(lambda x: cgi.escape(x), head)) + # generate table rows + for row in row_data: + html += "
      %s
      ".join( + map(lambda x: cgi.escape(x), row)) + html += "
      \n" + if has_more_data: + html += "only showing top %d %s\n" % ( + max_num_rows, "row" if max_num_rows == 1 else "rows") + return html + else: + return None @since(2.1) def checkpoint(self, eager=True): @@ -463,8 +505,8 @@ def collect(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.collectToPython() - return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + sock_info = self._jdf.collectToPython() + return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix @since(2.0) @@ -477,8 +519,8 @@ def toLocalIterator(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.toPythonIterator() - return _load_from_socket(port, BatchedSerializer(PickleSerializer())) + sock_info = self._jdf.toPythonIterator() + return _load_from_socket(sock_info, BatchedSerializer(PickleSerializer())) @ignore_unicode_prefix @since(1.3) @@ -1975,6 +2017,8 @@ def toPandas(self): .. note:: This method should only be used if the resulting Pandas's DataFrame is expected to be small, as all the data is loaded into the driver's memory. + .. note:: Usage with spark.sql.execution.arrow.enabled=True is experimental. + >>> df.toPandas() # doctest: +SKIP age name 0 2 Alice @@ -1985,13 +2029,12 @@ def toPandas(self): import pandas as pd - if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ - == "true": - timezone = self.sql_ctx.getConf("spark.sql.session.timeZone") + if self.sql_ctx._conf.pandasRespectSessionTimeZone(): + timezone = self.sql_ctx._conf.sessionLocalTimeZone() else: timezone = None - if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": + if self.sql_ctx._conf.arrowEnabled(): use_arrow = True try: from pyspark.sql.types import to_arrow_schema @@ -2001,8 +2044,7 @@ def toPandas(self): to_arrow_schema(self.schema) except Exception as e: - if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \ - .lower() == "true": + if self.sql_ctx._conf.arrowFallbackEnabled(): msg = ( "toPandas attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " @@ -2087,8 +2129,8 @@ def _collectAsArrow(self): .. note:: Experimental. """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(port, ArrowSerializer())) + sock_info = self._jdf.collectAsArrowToPython() + return list(_load_from_socket(sock_info, ArrowSerializer())) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index da32ab25cad0c..5ef73987a66a6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -152,6 +152,9 @@ def _(): _collect_list_doc = """ Aggregate function: returns a list of objects with duplicates. + .. note:: The function is non-deterministic because the order of collected results depends + on order of rows which may be non-deterministic after a shuffle. + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) >>> df2.agg(collect_list('age')).collect() [Row(collect_list(age)=[2, 5, 5])] @@ -159,6 +162,9 @@ def _(): _collect_set_doc = """ Aggregate function: returns a set of objects with duplicate elements eliminated. + .. note:: The function is non-deterministic because the order of collected results depends + on order of rows which may be non-deterministic after a shuffle. + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) >>> df2.agg(collect_set('age')).collect() [Row(collect_set(age)=[5, 2])] @@ -401,6 +407,9 @@ def first(col, ignorenulls=False): The function by default returns the first values it sees. It will return the first non-null value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + + .. note:: The function is non-deterministic because its results depends on order of rows which + may be non-deterministic after a shuffle. """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls) @@ -489,6 +498,9 @@ def last(col, ignorenulls=False): The function by default returns the last values it sees. It will return the last non-null value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + + .. note:: The function is non-deterministic because its results depends on order of rows + which may be non-deterministic after a shuffle. """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls) @@ -504,6 +516,8 @@ def monotonically_increasing_id(): within each partition in the lower 33 bits. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. + .. note:: The function is non-deterministic because its result depends on partition IDs. + As an example, consider a :class:`DataFrame` with two partitions, each with 3 records. This expression would return the following IDs: 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. @@ -536,6 +550,8 @@ def rand(seed=None): """Generates a random column with independent and identically distributed (i.i.d.) samples from U[0.0, 1.0]. + .. note:: The function is non-deterministic in general case. + >>> df.withColumn('rand', rand(seed=42) * 3).collect() [Row(age=2, name=u'Alice', rand=1.1568609015300986), Row(age=5, name=u'Bob', rand=1.403379671529166)] @@ -554,6 +570,8 @@ def randn(seed=None): """Generates a column with independent and identically distributed (i.i.d.) samples from the standard normal distribution. + .. note:: The function is non-deterministic in general case. + >>> df.withColumn('randn', randn(seed=42)).collect() [Row(age=2, name=u'Alice', randn=-0.7556247885860078), Row(age=5, name=u'Bob', randn=-0.0861619008451133)] @@ -1088,16 +1106,23 @@ def add_months(start, months): @since(1.5) -def months_between(date1, date2): +def months_between(date1, date2, roundOff=True): """ - Returns the number of months between date1 and date2. + Returns number of months between dates date1 and date2. + If date1 is later than date2, then the result is positive. + If date1 and date2 are on the same day of month, or both are the last day of month, + returns an integer (time of day will be ignored). + The result is rounded off to 8 digits unless `roundOff` is set to `False`. >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2']) >>> df.select(months_between(df.date1, df.date2).alias('months')).collect() - [Row(months=3.9495967...)] + [Row(months=3.94959677)] + >>> df.select(months_between(df.date1, df.date2, False).alias('months')).collect() + [Row(months=3.9495967741935485)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2))) + return Column(sc._jvm.functions.months_between( + _to_java_column(date1), _to_java_column(date2), roundOff)) @since(2.2) @@ -1260,11 +1285,21 @@ def from_utc_timestamp(timestamp, tz): that time as a timestamp in the given time zone. For example, 'GMT+1' would yield '2017-07-14 03:40:00.0'. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(from_utc_timestamp(df.t, "PST").alias('local_time')).collect() + :param timestamp: the column that contains timestamps + :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc + + .. versionchanged:: 2.4 + `tz` can take a :class:`Column` containing timezone ID strings. + + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz']) + >>> df.select(from_utc_timestamp(df.ts, "PST").alias('local_time')).collect() [Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))] + >>> df.select(from_utc_timestamp(df.ts, df.tz).alias('local_time')).collect() + [Row(local_time=datetime.datetime(1997, 2, 28, 19, 30))] """ sc = SparkContext._active_spark_context + if isinstance(tz, Column): + tz = _to_java_column(tz) return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz)) @@ -1275,11 +1310,21 @@ def to_utc_timestamp(timestamp, tz): zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield '2017-07-14 01:40:00.0'. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['ts']) + :param timestamp: the column that contains timestamps + :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc + + .. versionchanged:: 2.4 + `tz` can take a :class:`Column` containing timezone ID strings. + + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz']) >>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect() [Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))] + >>> df.select(to_utc_timestamp(df.ts, df.tz).alias('utc_time')).collect() + [Row(utc_time=datetime.datetime(1997, 2, 28, 1, 30))] """ sc = SparkContext._active_spark_context + if isinstance(tz, Column): + tz = _to_java_column(tz) return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) @@ -1794,6 +1839,25 @@ def create_map(*cols): return Column(jc) +@since(2.4) +def map_from_arrays(col1, col2): + """Creates a new map from two arrays. + + :param col1: name of column containing a set of keys. All elements should not be null + :param col2: name of column containing a set of values + + >>> df = spark.createDataFrame([([2, 5], ['a', 'b'])], ['k', 'v']) + >>> df.select(map_from_arrays(df.k, df.v).alias("map")).show() + +----------------+ + | map| + +----------------+ + |[2 -> a, 5 -> b]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_from_arrays(_to_java_column(col1), _to_java_column(col2))) + + @since(1.4) def array(*cols): """Creates a new array column. @@ -1830,6 +1894,55 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def arrays_overlap(a1, a2): + """ + Collection function: returns true if the arrays contain any common non-null element; if not, + returns null if both the arrays are non-empty and any of them contains a null element; returns + false otherwise. + + >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y']) + >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() + [Row(overlap=True), Row(overlap=False)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2))) + + +@since(2.4) +def slice(x, start, length): + """ + Collection function: returns an array containing all the elements in `x` from index `start` + (or starting from the end if `start` is negative) with the specified `length`. + >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x']) + >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect() + [Row(sliced=[2, 3]), Row(sliced=[5])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.slice(_to_java_column(x), start, length)) + + +@ignore_unicode_prefix +@since(2.4) +def array_join(col, delimiter, null_replacement=None): + """ + Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + `null_replacement` if set, otherwise they are ignored. + + >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) + >>> df.select(array_join(df.data, ",").alias("joined")).collect() + [Row(joined=u'a,b,c'), Row(joined=u'a')] + >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() + [Row(joined=u'a,b,c'), Row(joined=u'a,NULL')] + """ + sc = SparkContext._active_spark_context + if null_replacement is None: + return Column(sc._jvm.functions.array_join(_to_java_column(col), delimiter)) + else: + return Column(sc._jvm.functions.array_join( + _to_java_column(col), delimiter, null_replacement)) + + @since(1.5) @ignore_unicode_prefix def concat(*cols): @@ -1890,6 +2003,55 @@ def element_at(col, extraction): return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction)) +@since(2.4) +def array_remove(col, element): + """ + Collection function: Remove all elements that equal to element from the given array. + + :param col: name of column containing array + :param element: element to be removed from the array + + >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) + >>> df.select(array_remove(df.data, 1)).collect() + [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_remove(_to_java_column(col), element)) + + +@since(2.4) +def array_distinct(col): + """ + Collection function: removes duplicate values from the array. + :param col: name of column or expression + + >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data']) + >>> df.select(array_distinct(df.data)).collect() + [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(2.4) +def array_union(col1, col2): + """ + Collection function: returns an array of the elements in the union of col1 and col2, + without duplicates. + + :param col1: name of column containing array + :param col2: name of column containing array + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_union(df.c1, df.c2)).collect() + [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. @@ -2036,12 +2198,13 @@ def json_tuple(col, *fields): return Column(jc) +@ignore_unicode_prefix @since(2.1) def from_json(col, schema, options={}): """ - Parses a column containing a JSON string into a :class:`StructType` or :class:`ArrayType` - of :class:`StructType`\\s with the specified schema. Returns `null`, in the case of an - unparseable string. + Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType` + as keys type, :class:`StructType` or :class:`ArrayType` of :class:`StructType`\\s with + the specified schema. Returns `null`, in the case of an unparseable string. :param col: string column in json format :param schema: a StructType or ArrayType of StructType to use when parsing the json column. @@ -2058,16 +2221,23 @@ def from_json(col, schema, options={}): [Row(json=Row(a=1))] >>> df.select(from_json(df.value, "a INT").alias("json")).collect() [Row(json=Row(a=1))] + >>> df.select(from_json(df.value, "MAP").alias("json")).collect() + [Row(json={u'a': 1})] >>> data = [(1, '''[{"a": 1}]''')] >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(from_json(df.value, schema).alias("json")).collect() [Row(json=[Row(a=1)])] + >>> schema = schema_of_json(lit('''{"a": 0}''')) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=Row(a=1))] """ sc = SparkContext._active_spark_context if isinstance(schema, DataType): schema = schema.json() + elif isinstance(schema, Column): + schema = _to_java_column(schema) jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options) return Column(jc) @@ -2109,6 +2279,28 @@ def to_json(col, options={}): return Column(jc) +@ignore_unicode_prefix +@since(2.4) +def schema_of_json(col): + """ + Parses a column containing a JSON string and infers its schema in DDL format. + + :param col: string column in json format + + >>> from pyspark.sql.types import * + >>> data = [(1, '{"a": 1}')] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(schema_of_json(df.value).alias("json")).collect() + [Row(json=u'struct')] + >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect() + [Row(json=u'struct')] + """ + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.schema_of_json(_to_java_column(col)) + return Column(jc) + + @since(1.5) def size(col): """ @@ -2158,20 +2350,38 @@ def array_max(col): def sort_array(col, asc=True): """ Collection function: sorts the input array in ascending or descending order according - to the natural ordering of the array elements. + to the natural ordering of the array elements. Null elements will be placed at the beginning + of the returned array in ascending order or at the end of the returned array in descending + order. :param col: name of column or expression - >>> df = spark.createDataFrame([([2, 1, 3],),([1],),([],)], ['data']) + >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) >>> df.select(sort_array(df.data).alias('r')).collect() - [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])] + [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() - [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])] + [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(2.4) +def array_sort(col): + """ + Collection function: sorts the input array in ascending order. The elements of the input array + must be orderable. Null elements will be placed at the end of the returned array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) + >>> df.select(array_sort(df.data).alias('r')).collect() + [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_sort(_to_java_column(col))) + + @since(1.5) @ignore_unicode_prefix def reverse(col): @@ -2191,6 +2401,23 @@ def reverse(col): return Column(sc._jvm.functions.reverse(_to_java_column(col))) +@since(2.4) +def flatten(col): + """ + Collection function: creates a single array from an array of arrays. + If a structure of nested arrays is deeper than two levels, + only one level of nesting is removed. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) + >>> df.select(flatten(df.data).alias('r')).collect() + [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.flatten(_to_java_column(col))) + + @since(2.3) def map_keys(col): """ @@ -2231,6 +2458,99 @@ def map_values(col): return Column(sc._jvm.functions.map_values(_to_java_column(col))) +@since(2.4) +def map_entries(col): + """ + Collection function: Returns an unordered array of all entries in the given map. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_entries + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data") + >>> df.select(map_entries("data").alias("entries")).show() + +----------------+ + | entries| + +----------------+ + |[[1, a], [2, b]]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_entries(_to_java_column(col))) + + +@since(2.4) +def map_from_entries(col): + """ + Collection function: Returns a map created from the given array of entries. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_from_entries + >>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data") + >>> df.select(map_from_entries("data").alias("map")).show() + +----------------+ + | map| + +----------------+ + |[1 -> a, 2 -> b]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_from_entries(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(2.4) +def array_repeat(col, count): + """ + Collection function: creates an array containing a column repeated count times. + + >>> df = spark.createDataFrame([('ab',)], ['data']) + >>> df.select(array_repeat(df.data, 3).alias('r')).collect() + [Row(r=[u'ab', u'ab', u'ab'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count)) + + +@since(2.4) +def arrays_zip(*cols): + """ + Collection function: Returns a merged array of structs in which the N-th struct contains all + N-th values of input arrays. + + :param cols: columns of arrays to be merged. + + >>> from pyspark.sql.functions import arrays_zip + >>> df = spark.createDataFrame([(([1, 2, 3], [2, 3, 4]))], ['vals1', 'vals2']) + >>> df.select(arrays_zip(df.vals1, df.vals2).alias('zipped')).collect() + [Row(zipped=[Row(vals1=1, vals2=2), Row(vals1=2, vals2=3), Row(vals1=3, vals2=4)])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column))) + + +@since(2.4) +def map_concat(*cols): + """Returns the union of all the given maps. + + :param cols: list of column names (string) or list of :class:`Column` expressions + + >>> from pyspark.sql.functions import map_concat + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2") + >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False) + +--------------------------------+ + |map3 | + +--------------------------------+ + |[1 -> a, 2 -> b, 3 -> c, 1 -> d]| + +--------------------------------+ + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.map_concat(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): @@ -2309,6 +2629,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. Default: SCALAR. + .. note:: Experimental + The function type of the UDF can be one of the following: 1. SCALAR @@ -2350,7 +2672,9 @@ def pandas_udf(f=None, returnType=None, functionType=None): A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned - `pandas.DataFrame`. + `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match + the field names in the defined returnType schema if specified as strings, or match the + field data types by position if not strings, e.g. integer indices. The length of the returned `pandas.DataFrame` can be arbitrary. Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. @@ -2399,6 +2723,12 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2|6.0| +---+---+ + .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is + recommended to explicitly index the columns by name to ensure the positions are correct, + or alternatively use an `OrderedDict`. + For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or + `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`. + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` 3. GROUPED_AGG @@ -2408,10 +2738,12 @@ def pandas_udf(f=None, returnType=None, functionType=None): The returned scalar can be either a python primitive type, e.g., `int` or `float` or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. - :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as - output types. + :class:`MapType` and :class:`StructType` are currently not supported as output types. + + Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and + :class:`pyspark.sql.Window` - Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` + This example shows using grouped aggregated UDFs with groupby: >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( @@ -2428,7 +2760,31 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 6.0| +---+-----------+ - .. seealso:: :meth:`pyspark.sql.GroupedData.agg` + This example shows using grouped aggregated UDFs as window functions. Note that only + unbounded window frame is supported at the moment: + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP + ... def mean_udf(v): + ... return v.mean() + >>> w = Window.partitionBy('id') \\ + ... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP + +---+----+------+ + | id| v|mean_v| + +---+----+------+ + | 1| 1.0| 1.5| + | 1| 2.0| 1.5| + | 2| 3.0| 6.0| + | 2| 5.0| 6.0| + | 2|10.0| 6.0| + +---+----+------+ + + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window` .. note:: The user-defined functions are considered deterministic by default. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 3505065b648f2..0906c9c6b329a 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -236,6 +236,8 @@ def apply(self, udf): into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory. + .. note:: Experimental + :param udf: a grouped map user-defined function returned by :func:`pyspark.sql.functions.pandas_udf`. diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 6bd79bc2f43e5..3efe2adb6e2a4 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None, lineSep=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, + dropFieldIfAllNull=None, encoding=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -237,8 +238,17 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param encoding: allows to forcibly set one of standard basic or extended encoding for + the JSON files. For example UTF-16BE, UTF-32LE. If None is set, + the encoding of input JSON will be detected automatically + when the multiLine option is set to ``true``. :param lineSep: defines the line separator that should be used for parsing. If None is set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + :param samplingRatio: defines fraction of input JSON objects used for schema inferring. + If None is set, it uses the default value, ``1.0``. + :param dropFieldIfAllNull: whether to ignore column of all null values or empty + array/struct during schema inference. If None is set, it + uses the default value, ``false``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -256,7 +266,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, + samplingRatio=samplingRatio, encoding=encoding) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -337,7 +348,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None): + columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, + samplingRatio=None, enforceSchema=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -364,6 +376,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. + :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be + forcibly applied to datasource files, and headers in CSV files will be + ignored. If the option is set to ``false``, the schema will be + validated against all headers in CSV files or the first header in RDD + if the ``header`` option is set to ``true``. Field names in the schema + and column names in CSV headers are checked by their positions + taking into account ``spark.sql.caseSensitive``. If None is set, + ``true`` is used by default. Though the default value is ``true``, + it is recommended to disable the ``enforceSchema`` option + to avoid incorrect results. :param ignoreLeadingWhiteSpace: A flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, ``false``. @@ -420,6 +442,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non the quote character. If None is set, the default value is escape character when escape and quote characters are different, ``\0`` otherwise. + :param samplingRatio: defines fraction of rows used for schema inferring. + If None is set, it uses the default value, ``1.0``. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -438,7 +462,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, + enforceSchema=enforceSchema) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -749,7 +774,7 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) @since(1.4) def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, - lineSep=None): + lineSep=None, encoding=None): """Saves the content of the :class:`DataFrame` in JSON format (`JSON Lines text format or newline-delimited JSON `_) at the specified path. @@ -773,6 +798,8 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. + :param encoding: specifies encoding (charset) of saved json files. If None is set, + the default UTF-8 charset will be used. :param lineSep: defines the line separator that should be used for writing. If None is set, it uses the default value, ``\\n``. @@ -781,7 +808,7 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm self.mode(mode) self._set_opts( compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, - lineSep=lineSep) + lineSep=lineSep, encoding=encoding) self._jwrite.json(path) @since(1.4) @@ -966,7 +993,7 @@ def _test(): globs = pyspark.sql.readwriter.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') try: - spark = SparkSession.builder.enableHiveSupport().getOrCreate() + spark = SparkSession.builder.getOrCreate() except py4j.protocol.Py4JError: spark = SparkSession(sc) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 13d6e2e53dbd0..f1ad6b1212ed9 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -547,6 +547,33 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): df._schema = schema return df + @staticmethod + def _create_shell_session(): + """ + Initialize a SparkSession for a pyspark shell session. This is called from shell.py + to make error handling simpler without needing to declare local variables in that + script, which would expose those to users. + """ + import py4j + from pyspark.conf import SparkConf + from pyspark.context import SparkContext + try: + # Try to access HiveConf, it will raise exception if Hive is not added + conf = SparkConf() + if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive': + SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() + return SparkSession.builder\ + .enableHiveSupport()\ + .getOrCreate() + else: + return SparkSession.builder.getOrCreate() + except (py4j.protocol.Py4JError, TypeError): + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") + + return SparkSession.builder.getOrCreate() + @since(2.0) @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): @@ -584,6 +611,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr .. versionchanged:: 2.1 Added verifySchema. + .. note:: Usage with spark.sql.execution.arrow.enabled=True is experimental. + >>> l = [('Alice', 1)] >>> spark.createDataFrame(l).collect() [Row(_1=u'Alice', _2=1)] diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 15f9407389864..8c1fd4af674d7 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -24,12 +24,14 @@ else: intlike = (int, long) +from py4j.java_gateway import java_import + from pyspark import since, keyword_only from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.readwriter import OptionUtils, to_str from pyspark.sql.types import * -from pyspark.sql.utils import StreamingQueryException +from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException __all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"] @@ -564,7 +566,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None): + columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, + enforceSchema=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -592,6 +595,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. + :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be + forcibly applied to datasource files, and headers in CSV files will be + ignored. If the option is set to ``false``, the schema will be + validated against all headers in CSV files or the first header in RDD + if the ``header`` option is set to ``true``. Field names in the schema + and column names in CSV headers are checked by their positions + taking into account ``spark.sql.caseSensitive``. If None is set, + ``true`` is used by default. Though the default value is ``true``, + it is recommended to disable the ``enforceSchema`` option + to avoid incorrect results. :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, ``false``. @@ -664,7 +677,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: @@ -843,6 +856,197 @@ def trigger(self, processingTime=None, once=None, continuous=None): self._jwrite = self._jwrite.trigger(jTrigger) return self + @since(2.4) + def foreach(self, f): + """ + Sets the output of the streaming query to be processed using the provided writer ``f``. + This is often used to write the output of a streaming query to arbitrary storage systems. + The processing logic can be specified in two ways. + + #. A **function** that takes a row as input. + This is a simple way to express your processing logic. Note that this does + not allow you to deduplicate generated data when failures cause reprocessing of + some input data. That would require you to specify the processing logic in the next + way. + + #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods. + The object can have the following methods. + + * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing + (for example, open a connection, start a transaction, etc). Additionally, you can + use the `partition_id` and `epoch_id` to deduplicate regenerated data + (discussed later). + + * ``process(row)``: *Non-optional* method that processes each :class:`Row`. + + * ``close(error)``: *Optional* method that finalizes and cleans up (for example, + close connection, commit transaction, etc.) after all rows have been processed. + + The object will be used by Spark in the following way. + + * A single copy of this object is responsible of all the data generated by a + single task in a query. In other words, one instance is responsible for + processing one partition of the data generated in a distributed manner. + + * This object must be serializable because each task will get a fresh + serialized-deserialized copy of the provided object. Hence, it is strongly + recommended that any initialization for writing data (e.g. opening a + connection or starting a transaction) is done after the `open(...)` + method has been called, which signifies that the task is ready to generate data. + + * The lifecycle of the methods are as follows. + + For each partition with ``partition_id``: + + ... For each batch/epoch of streaming data with ``epoch_id``: + + ....... Method ``open(partitionId, epochId)`` is called. + + ....... If ``open(...)`` returns true, for each row in the partition and + batch/epoch, method ``process(row)`` is called. + + ....... Method ``close(errorOrNull)`` is called with error (if any) seen while + processing rows. + + Important points to note: + + * The `partitionId` and `epochId` can be used to deduplicate generated data when + failures cause reprocessing of some input data. This depends on the execution + mode of the query. If the streaming query is being executed in the micro-batch + mode, then every partition represented by a unique tuple (partition_id, epoch_id) + is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used + to deduplicate and/or transactionally commit data and achieve exactly-once + guarantees. However, if the streaming query is being executed in the continuous + mode, then this guarantee does not hold and therefore should not be used for + deduplication. + + * The ``close()`` method (if exists) will be called if `open()` method exists and + returns successfully (irrespective of the return value), except if the Python + crashes in the middle. + + .. note:: Evolving. + + >>> # Print every row using a function + >>> def print_row(row): + ... print(row) + ... + >>> writer = sdf.writeStream.foreach(print_row) + >>> # Print every row using a object with process() method + >>> class RowPrinter: + ... def open(self, partition_id, epoch_id): + ... print("Opened %d, %d" % (partition_id, epoch_id)) + ... return True + ... def process(self, row): + ... print(row) + ... def close(self, error): + ... print("Closed with error: %s" % str(error)) + ... + >>> writer = sdf.writeStream.foreach(RowPrinter()) + """ + + from pyspark.rdd import _wrap_function + from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + from pyspark.taskcontext import TaskContext + + if callable(f): + # The provided object is a callable function that is supposed to be called on each row. + # Construct a function that takes an iterator and calls the provided function on each + # row. + def func_without_process(_, iterator): + for x in iterator: + f(x) + return iter([]) + + func = func_without_process + + else: + # The provided object is not a callable function. Then it is expected to have a + # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and + # 'close(error)' methods. + + if not hasattr(f, 'process'): + raise Exception("Provided object does not have a 'process' method") + + if not callable(getattr(f, 'process')): + raise Exception("Attribute 'process' in provided object is not callable") + + def doesMethodExist(method_name): + exists = hasattr(f, method_name) + if exists and not callable(getattr(f, method_name)): + raise Exception( + "Attribute '%s' in provided object is not callable" % method_name) + return exists + + open_exists = doesMethodExist('open') + close_exists = doesMethodExist('close') + + def func_with_open_process_close(partition_id, iterator): + epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId') + if epoch_id: + epoch_id = int(epoch_id) + else: + raise Exception("Could not get batch id from TaskContext") + + # Check if the data should be processed + should_process = True + if open_exists: + should_process = f.open(partition_id, epoch_id) + + error = None + + try: + if should_process: + for x in iterator: + f.process(x) + except Exception as ex: + error = ex + finally: + if close_exists: + f.close(error) + if error: + raise error + + return iter([]) + + func = func_with_open_process_close + + serializer = AutoBatchedSerializer(PickleSerializer()) + wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer) + jForeachWriter = \ + self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter( + wrapped_func, self._df._jdf.schema()) + self._jwrite.foreach(jForeachWriter) + return self + + @since(2.4) + def foreachBatch(self, func): + """ + Sets the output of the streaming query to be processed using the provided + function. This is supported only the in the micro-batch execution modes (that is, when the + trigger is not continuous). In every micro-batch, the provided function will be called in + every micro-batch with (i) the output rows as a DataFrame and (ii) the batch identifier. + The batchId can be used deduplicate and transactionally write the output + (that is, the provided Dataset) to external systems. The output DataFrame is guaranteed + to exactly same for the same batchId (assuming all operations are deterministic in the + query). + + .. note:: Evolving. + + >>> def func(batch_df, batch_id): + ... batch_df.collect() + ... + >>> writer = sdf.writeStream.foreach(func) + """ + + from pyspark.java_gateway import ensure_callback_server_started + gw = self._spark._sc._gateway + java_import(gw.jvm, "org.apache.spark.sql.execution.streaming.sources.*") + + wrapped_func = ForeachBatchFunction(self._spark, func) + gw.jvm.PythonForeachBatchHelper.callForeachBatch(self._jwrite, wrapped_func) + ensure_callback_server_started(gw) + return self + @ignore_unicode_prefix @since(2.0) def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4e99c8e3c6b10..565654e7f03bb 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -685,6 +685,13 @@ def test_multiline_json(self): multiLine=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_encoding_json(self): + people_array = self.spark.read\ + .json("python/test_support/sql/people_array_utf16le.json", + multiLine=True, encoding="UTF-16LE") + expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')] + self.assertEqual(people_array.collect(), expected) + def test_linesep_json(self): df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") expected = [Row(_corrupt_record=None, name=u'Michael'), @@ -1862,6 +1869,299 @@ def test_query_manager_await_termination(self): q.stop() shutil.rmtree(tmpPath) + class ForeachWriterTester: + + def __init__(self, spark): + self.spark = spark + + def write_open_event(self, partitionId, epochId): + self._write_event( + self.open_events_dir, + {'partition': partitionId, 'epoch': epochId}) + + def write_process_event(self, row): + self._write_event(self.process_events_dir, {'value': 'text'}) + + def write_close_event(self, error): + self._write_event(self.close_events_dir, {'error': str(error)}) + + def write_input_file(self): + self._write_event(self.input_dir, "text") + + def open_events(self): + return self._read_events(self.open_events_dir, 'partition INT, epoch INT') + + def process_events(self): + return self._read_events(self.process_events_dir, 'value STRING') + + def close_events(self): + return self._read_events(self.close_events_dir, 'error STRING') + + def run_streaming_query_on_writer(self, writer, num_files): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + for i in range(num_files): + self.write_input_file() + sq.processAllAvailable() + finally: + self.stop_all() + + def assert_invalid_writer(self, writer, msg=None): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + self.write_input_file() + sq.processAllAvailable() + self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected + except Exception as e: + if msg: + assert msg in str(e), "%s not in %s" % (msg, str(e)) + + finally: + self.stop_all() + + def stop_all(self): + for q in self.spark._wrapped.streams.active: + q.stop() + + def _reset(self): + self.input_dir = tempfile.mkdtemp() + self.open_events_dir = tempfile.mkdtemp() + self.process_events_dir = tempfile.mkdtemp() + self.close_events_dir = tempfile.mkdtemp() + + def _read_events(self, dir, json): + rows = self.spark.read.schema(json).json(dir).collect() + dicts = [row.asDict() for row in rows] + return dicts + + def _write_event(self, dir, event): + import uuid + with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f: + f.write("%s\n" % str(event)) + + def __getstate__(self): + return (self.open_events_dir, self.process_events_dir, self.close_events_dir) + + def __setstate__(self, state): + self.open_events_dir, self.process_events_dir, self.close_events_dir = state + + def test_streaming_foreach_with_simple_function(self): + tester = self.ForeachWriterTester(self.spark) + + def foreach_func(row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(foreach_func, 2) + self.assertEqual(len(tester.process_events()), 2) + + def test_streaming_foreach_with_basic_open_process_close(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partitionId, epochId): + tester.write_open_event(partitionId, epochId) + return True + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + open_events = tester.open_events() + self.assertEqual(len(open_events), 2) + self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1}) + + self.assertEqual(len(tester.process_events()), 2) + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_with_open_returning_false(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return False + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + self.assertEqual(len(tester.open_events()), 2) + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_without_open_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 2) + + def test_streaming_foreach_without_close_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return True + + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 2) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_without_open_and_close_methods(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_with_process_throwing_error(self): + from pyspark.sql.utils import StreamingQueryException + + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + raise Exception("test error") + + def close(self, error): + tester.write_close_event(error) + + try: + tester.run_streaming_query_on_writer(ForeachWriter(), 1) + self.fail("bad writer did not fail the query") # this is not expected + except StreamingQueryException as e: + # TODO: Verify whether original error message is inside the exception + pass + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + close_events = tester.close_events() + self.assertEqual(len(close_events), 1) + # TODO: Verify whether original error message is inside the exception + + def test_streaming_foreach_with_invalid_writers(self): + + tester = self.ForeachWriterTester(self.spark) + + def func_with_iterator_input(iter): + for x in iter: + print(x) + + tester.assert_invalid_writer(func_with_iterator_input) + + class WriterWithoutProcess: + def open(self, partition): + pass + + tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") + + class WriterWithNonCallableProcess(): + process = True + + tester.assert_invalid_writer(WriterWithNonCallableProcess(), + "'process' in provided object is not callable") + + class WriterWithNoParamProcess(): + def process(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamProcess()) + + # Abstract class for tests below + class WithProcess(): + def process(self, row): + pass + + class WriterWithNonCallableOpen(WithProcess): + open = True + + tester.assert_invalid_writer(WriterWithNonCallableOpen(), + "'open' in provided object is not callable") + + class WriterWithNoParamOpen(WithProcess): + def open(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamOpen()) + + class WriterWithNonCallableClose(WithProcess): + close = True + + tester.assert_invalid_writer(WriterWithNonCallableClose(), + "'close' in provided object is not callable") + + def test_streaming_foreachBatch(self): + q = None + collected = dict() + + def collectBatch(batch_df, batch_id): + collected[batch_id] = batch_df.collect() + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.assertTrue(0 in collected) + self.assertTrue(len(collected[0]), 2) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_propagates_python_errors(self): + from pyspark.sql.utils import StreamingQueryException + + q = None + + def collectBatch(df, id): + raise Exception("this should fail the query") + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.fail("Expected a failure") + except StreamingQueryException as e: + self.assertTrue("this should fail" in str(e)) + finally: + if q: + q.stop() + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) @@ -3018,9 +3318,134 @@ def test_sort_with_nulls_order(self): df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(), [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)]) + def test_json_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x)) + schema = self.spark.read.option('inferSchema', True) \ + .option('samplingRatio', 0.5) \ + .json(rdd).schema + self.assertEquals(schema, StructType([StructField("a", LongType(), True)])) + + def test_csv_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '0.1' if x == 1 else str(x)) + schema = self.spark.read.option('inferSchema', True)\ + .csv(rdd, samplingRatio=0.5).schema + self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)])) + + def test_checking_csv_header(self): + path = tempfile.mkdtemp() + shutil.rmtree(path) + try: + self.spark.createDataFrame([[1, 1000], [2000, 2]])\ + .toDF('f1', 'f2').write.option("header", "true").csv(path) + schema = StructType([ + StructField('f2', IntegerType(), nullable=True), + StructField('f1', IntegerType(), nullable=True)]) + df = self.spark.read.option('header', 'true').schema(schema)\ + .csv(path, enforceSchema=False) + self.assertRaisesRegexp( + Exception, + "CSV header does not conform to the schema", + lambda: df.collect()) + finally: + shutil.rmtree(path) + + def test_repr_behaviors(self): + import re + pattern = re.compile(r'^ *\|', re.MULTILINE) + df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value")) + + # test when eager evaluation is enabled and _repr_html_ will not be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): + expected1 = """+-----+-----+ + || key|value| + |+-----+-----+ + || 1| 1| + ||22222|22222| + |+-----+-----+ + |""" + self.assertEquals(re.sub(pattern, '', expected1), df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + expected2 = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + ||222| 222| + |+---+-----+ + |""" + self.assertEquals(re.sub(pattern, '', expected2), df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + expected3 = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + |+---+-----+ + |only showing top 1 row + |""" + self.assertEquals(re.sub(pattern, '', expected3), df.__repr__()) + + # test when eager evaluation is enabled and _repr_html_ will be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): + expected1 = """ + | + | + | + |
      keyvalue
      11
      2222222222
      + |""" + self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + expected2 = """ + | + | + | + |
      keyvalue
      11
      222222
      + |""" + self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + expected3 = """ + | + | + |
      keyvalue
      11
      + |only showing top 1 row + |""" + self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_()) + + # test when eager evaluation is disabled and _repr_html_ will be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}): + expected = "DataFrame[key: bigint, value: string]" + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + class HiveSparkSubmitTests(SparkSubmitTests): + @classmethod + def setUpClass(cls): + # get a SparkContext to check for availability of Hive + sc = SparkContext('local[4]', cls.__name__) + cls.hive_available = True + try: + sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.hive_available = False + except TypeError: + cls.hive_available = False + finally: + # we don't need this SparkContext for the test + sc.stop() + + def setUp(self): + super(HiveSparkSubmitTests, self).setUp() + if not self.hive_available: + self.skipTest("Hive is not available.") + def test_hivecontext(self): # This test checks that HiveContext is using Hive metastore (SPARK-16224). # It sets a metastore url and checks if there is a derby dir created by @@ -3050,8 +3475,8 @@ def test_hivecontext(self): |print(hive_context.sql("show databases").collect()) """) proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", - "--driver-class-path", hive_site_dir, script], + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", + "--driver-class-path", hive_site_dir, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -3088,23 +3513,28 @@ def setUpClass(cls): filename_pattern = ( "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" "TestQueryExecutionListener.class") - if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): - raise unittest.SkipTest( + cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern))) + + if cls.has_listener: + # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. + cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.queryExecutionListeners", + "org.apache.spark.sql.TestQueryExecutionListener") \ + .getOrCreate() + + def setUp(self): + if not self.has_listener: + raise self.skipTest( "'org.apache.spark.sql.TestQueryExecutionListener' is not " "available. Will skip the related tests.") - # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. - cls.spark = SparkSession.builder \ - .master("local[4]") \ - .appName(cls.__name__) \ - .config( - "spark.sql.queryExecutionListeners", - "org.apache.spark.sql.TestQueryExecutionListener") \ - .getOrCreate() - @classmethod def tearDownClass(cls): - cls.spark.stop() + if hasattr(cls, "spark"): + cls.spark.stop() def tearDown(self): self.spark._jvm.OnSuccessCall.clear() @@ -3188,18 +3618,22 @@ class HiveContextSQLTests(ReusedPySparkTestCase): def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + cls.hive_available = True try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: - cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False except TypeError: - cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False os.unlink(cls.tempdir.name) - cls.spark = HiveContext._createForTesting(cls.sc) - cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - cls.df = cls.sc.parallelize(cls.testData).toDF() + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.df = cls.sc.parallelize(cls.testData).toDF() + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") @classmethod def tearDownClass(cls): @@ -4029,6 +4463,61 @@ def foo(df): def foo(k, v, w): return k + def test_stopiteration_in_udf(self): + from pyspark.sql.functions import udf, pandas_udf, PandasUDFType + from py4j.protocol import Py4JJavaError + + def foo(x): + raise StopIteration() + + def foofoo(x, y): + raise StopIteration() + + exc_message = "Caught StopIteration thrown from user's code; failing the task" + df = self.spark.range(0, 100) + + # plain udf (test for SPARK-23754) + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.withColumn('v', udf(foo)('id')).collect + ) + + # pandas scalar udf + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.withColumn( + 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') + ).collect + ) + + # pandas grouped map + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').apply( + pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) + ).collect + ) + + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').apply( + pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP) + ).collect + ) + + # pandas grouped agg + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').agg( + pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id') + ).collect + ) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -4295,7 +4784,6 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10) with QuietTest(self.sc): with self.assertRaisesRegexp( NotImplementedError, @@ -4629,6 +5117,26 @@ def test_supported_types(self): self.assertPandasEqual(expected2, result2) self.assertPandasEqual(expected3, result3) + def test_array_type_correct(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col + + df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id") + + output_schema = StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType()))]) + + udf = pandas_udf( + lambda pdf: pdf, + output_schema, + PandasUDFType.GROUPED_MAP + ) + + result = df.groupby('id').apply(udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True) + self.assertPandasEqual(expected, result) + def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4860,6 +5368,125 @@ def foo3(key, pdf): expected4 = udf3.func((), pdf) self.assertPandasEqual(expected4, result4) + def test_column_order(self): + from collections import OrderedDict + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + # Helper function to set column names from a list + def rename_pdf(pdf, names): + pdf.rename(columns={old: new for old, new in + zip(pd_result.columns, names)}, inplace=True) + + df = self.data + grouped_df = df.groupby('id') + grouped_pdf = df.toPandas().groupby('id') + + # Function returns a pdf with required column names, but order could be arbitrary using dict + def change_col_order(pdf): + # Constructing a DataFrame from a dict should result in the same order, + # but use from_items to ensure the pdf column order is different than schema + return pd.DataFrame.from_items([ + ('id', pdf.id), + ('u', pdf.v * 2), + ('v', pdf.v)]) + + ordered_udf = pandas_udf( + change_col_order, + 'id long, v int, u int', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result should assign columns by name from the pdf + result = grouped_df.apply(ordered_udf).sort('id', 'v')\ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(change_col_order) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + # Function returns a pdf with positional columns, indexed by range + def range_col_order(pdf): + # Create a DataFrame with positional columns, fix types to long + return pd.DataFrame(list(zip(pdf.id, pdf.v * 3, pdf.v)), dtype='int64') + + range_udf = pandas_udf( + range_col_order, + 'id long, u long, v long', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result uses positional columns from the pdf + result = grouped_df.apply(range_udf).sort('id', 'v') \ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(range_col_order) + rename_pdf(pd_result, ['id', 'u', 'v']) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + # Function returns a pdf with columns indexed with integers + def int_index(pdf): + return pd.DataFrame(OrderedDict([(0, pdf.id), (1, pdf.v * 4), (2, pdf.v)])) + + int_index_udf = pandas_udf( + int_index, + 'id long, u int, v int', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result should assign columns by position of integer index + result = grouped_df.apply(int_index_udf).sort('id', 'v') \ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(int_index) + rename_pdf(pd_result, ['id', 'u', 'v']) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) + def column_name_typo(pdf): + return pd.DataFrame({'iid': pdf.id, 'v': pdf.v}) + + @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) + def invalid_positional_types(pdf): + return pd.DataFrame([(u'a', 1.2)]) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "KeyError: 'id'"): + grouped_df.apply(column_name_typo).collect() + with self.assertRaisesRegexp(Exception, "No cast implemented"): + grouped_df.apply(invalid_positional_types).collect() + + def test_positional_assignment_conf(self): + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with self.sql_conf({"spark.sql.execution.pandas.groupedMap.assignColumnsByPosition": True}): + + @pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP) + def foo(_): + return pd.DataFrame([('hi', 1)], columns=['x', 'y']) + + df = self.data + result = df.groupBy('id').apply(foo).select('a', 'b').collect() + for r in result: + self.assertEqual(r.a, 'hi') + self.assertEqual(r.b, 1) + + def test_self_join_with_pandas(self): + import pyspark.sql.functions as F + + @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) + def dummy_pandas_udf(df): + return df[['key', 'col']] + + df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'), + Row(key=2, col='C')]) + df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf) + + # this was throwing an AnalysisException before SPARK-24208 + res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'), + F.col('temp0.key') == F.col('temp1.key')) + self.assertEquals(res.count(), 5) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -5168,8 +5795,8 @@ def test_complex_groupby(self): expected2 = df.groupby().agg(sum(df.v)) # groupby one column and one sql expression - result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)) - expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)) + result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2) + expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2) # groupby one python UDF result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) @@ -5280,6 +5907,15 @@ def test_retain_group_columns(self): expected1 = df.groupby(df.id).agg(sum(df.v)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + def test_array_type(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = self.data + + array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG) + result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2')) + self.assertEquals(result1.first()['v2'], [1.0, 2.0]) + def test_invalid_args(self): from pyspark.sql.functions import mean @@ -5305,9 +5941,238 @@ def test_invalid_args(self): 'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) +class WindowPandasUDFTests(ReusedSQLTestCase): + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))) \ + .drop('vs') \ + .withColumn('w', lit(1.0)) + + @property + def python_plus_one(self): + from pyspark.sql.functions import udf + return udf(lambda v: v + 1, 'double') + + @property + def pandas_scalar_time_two(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + return pandas_udf(lambda v: v * 2, 'double') + + @property + def pandas_agg_mean_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def avg(v): + return v.mean() + return avg + + @property + def pandas_agg_max_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def max(v): + return v.max() + return max + + @property + def pandas_agg_min_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def min(v): + return v.min() + return min + + @property + def unbounded_window(self): + return Window.partitionBy('id') \ + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + + @property + def ordered_window(self): + return Window.partitionBy('id').orderBy('v') + + @property + def unpartitioned_window(self): + return Window.partitionBy() + + def test_simple(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, percent_rank, mean, max + + df = self.data + w = self.unbounded_window + + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w)) + expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) + + result2 = df.select(mean_udf(df['v']).over(w)) + expected2 = df.select(mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_multiple_udfs(self): + from pyspark.sql.functions import max, min, mean + + df = self.data + w = self.unbounded_window + + result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \ + .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \ + .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w)) + + expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \ + .withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('min_w', min(df['w']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_replace_existing(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + + result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w)) + expected1 = df.withColumn('v', mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_sql(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1) + expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_udf(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + + plus_one = self.python_plus_one + time_two = self.pandas_scalar_time_two + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn( + 'v2', + plus_one(mean_udf(plus_one(df['v'])).over(w))) + expected1 = df.withColumn( + 'v2', + plus_one(mean(plus_one(df['v'])).over(w))) + + result2 = df.withColumn( + 'v2', + time_two(mean_udf(time_two(df['v'])).over(w))) + expected2 = df.withColumn( + 'v2', + time_two(mean(time_two(df['v'])).over(w))) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_without_partitionBy(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unpartitioned_window + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('v2', mean_udf(df['v']).over(w)) + expected1 = df.withColumn('v2', mean(df['v']).over(w)) + + result2 = df.select(mean_udf(df['v']).over(w)) + expected2 = df.select(mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_mixed_sql_and_udf(self): + from pyspark.sql.functions import max, min, rank, col + + df = self.data + w = self.unbounded_window + ow = self.ordered_window + max_udf = self.pandas_agg_max_udf + min_udf = self.pandas_agg_min_udf + + result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w)) + expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w)) + + # Test mixing sql window function and window udf in the same expression + result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w)) + expected2 = expected1 + + # Test chaining sql aggregate function and udf + result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ + .withColumn('min_v', min(df['v']).over(w)) \ + .withColumn('v_diff', col('max_v') - col('min_v')) \ + .drop('max_v', 'min_v') + expected3 = expected1 + + # Test mixing sql window function and udf + result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ + .withColumn('rank', rank().over(ow)) + expected4 = df.withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('rank', rank().over(ow)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + + def test_array_type(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = self.data + w = self.unbounded_window + + array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG) + result1 = df.withColumn('v2', array_udf(df['v']).over(w)) + self.assertEquals(result1.first()['v2'], [1.0, 2.0]) + + def test_invalid_args(self): + from pyspark.sql.functions import mean, pandas_udf, PandasUDFType + + df = self.data + w = self.unbounded_window + ow = self.ordered_window + mean_udf = self.pandas_agg_mean_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + '.*not supported within a window function'): + foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) + df.withColumn('v2', foo_udf(df['v']).over(w)) + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + '.*Only unbounded window frame is supported.*'): + df.withColumn('mean_v', mean_udf(df['v']).over(ow)) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 1f6534836d64a..3cd7a2ef115af 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -289,7 +289,8 @@ def __init__(self, elementType, containsNull=True): >>> ArrayType(StringType(), False) == ArrayType(StringType()) False """ - assert isinstance(elementType, DataType), "elementType should be DataType" + assert isinstance(elementType, DataType),\ + "elementType %s should be an instance of %s" % (elementType, DataType) self.elementType = elementType self.containsNull = containsNull @@ -343,8 +344,10 @@ def __init__(self, keyType, valueType, valueContainsNull=True): ... == MapType(StringType(), FloatType())) False """ - assert isinstance(keyType, DataType), "keyType should be DataType" - assert isinstance(valueType, DataType), "valueType should be DataType" + assert isinstance(keyType, DataType),\ + "keyType %s should be an instance of %s" % (keyType, DataType) + assert isinstance(valueType, DataType),\ + "valueType %s should be an instance of %s" % (valueType, DataType) self.keyType = keyType self.valueType = valueType self.valueContainsNull = valueContainsNull @@ -402,8 +405,9 @@ def __init__(self, name, dataType, nullable=True, metadata=None): ... == StructField("f2", StringType(), True)) False """ - assert isinstance(dataType, DataType), "dataType should be DataType" - assert isinstance(name, basestring), "field name should be string" + assert isinstance(dataType, DataType),\ + "dataType %s should be an instance of %s" % (dataType, DataType) + assert isinstance(name, basestring), "field name %s should be string" % (name) if not isinstance(name, str): name = name.encode('utf-8') self.name = name diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 45363f089a73d..bb9ce02c4b60f 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -150,3 +150,26 @@ def require_minimum_pyarrow_version(): if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): raise ImportError("PyArrow >= %s must be installed; however, " "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__)) + + +class ForeachBatchFunction(object): + """ + This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps + the user-defined 'foreachBatch' function such that it can be called from the JVM when + the query is active. + """ + + def __init__(self, sql_ctx, func): + self.sql_ctx = sql_ctx + self.func = func + + def call(self, jdf, batch_id): + from pyspark.sql.dataframe import DataFrame + try: + self.func(DataFrame(jdf, self.sql_ctx), batch_id) + except Exception as e: + self.error = e + raise e + + class Java: + implements = ['org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction'] diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 17c34f8a1c54c..a4515828d180c 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -79,22 +79,8 @@ def _ensure_initialized(cls): java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") - # start callback server - # getattr will fallback to JVM, so we cannot test by hasattr() - if "_callback_server" not in gw.__dict__ or gw._callback_server is None: - gw.callback_server_parameters.eager_load = True - gw.callback_server_parameters.daemonize = True - gw.callback_server_parameters.daemonize_connections = True - gw.callback_server_parameters.port = 0 - gw.start_callback_server(gw.callback_server_parameters) - cbport = gw._callback_server.server_socket.getsockname()[1] - gw._callback_server.port = cbport - # gateway with real port - gw._python_proxy_port = gw._callback_server.port - # get the GatewayServer object in JVM by ID - jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) - # update the port of CallbackClient with real port - jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) + from pyspark.java_gateway import ensure_callback_server_started + ensure_callback_server_started(gw) # register serializer for TransformFunction # it happens before creating SparkContext when loading from checkpointing @@ -338,7 +324,7 @@ def transform(self, dstreams, transformFunc): jdstreams = [d._jdstream for d in dstreams] # change the final serializer to sc.serializer func = TransformFunction(self._sc, - lambda t, *rdds: transformFunc(rdds).map(lambda x: x), + lambda t, *rdds: transformFunc(rdds), *[d._jrdd_deserializer for d in dstreams]) jfunc = self._jvm.TransformFunction(func) jdstream = self._jssc.transform(jdstreams, jfunc) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 103940923dd4d..373784f826677 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -63,7 +63,7 @@ def setUpClass(cls): class_name = cls.__name__ conf = SparkConf().set("spark.default.parallelism", 1) cls.sc = SparkContext(appName=class_name, conf=conf) - cls.sc.setCheckpointDir("/tmp") + cls.sc.setCheckpointDir(tempfile.mkdtemp()) @classmethod def tearDownClass(cls): @@ -779,6 +779,12 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) + def test_transform_pairrdd(self): + # This regression test case is for SPARK-17756. + dstream = self.ssc.queueStream( + [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd)) + self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3)) + def test_get_active(self): self.assertEqual(StreamingContext.getActive(), None) @@ -1549,7 +1555,9 @@ def search_kinesis_asl_assembly_jar(): kinesis_jar_present = True jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) - os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars + existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + jars_args = "--jars %s" % jars + os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args]) testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, StreamingListenerTests] @@ -1590,11 +1598,11 @@ def search_kinesis_asl_assembly_jar(): sys.stderr.write("[Running %s]\n" % (testcase)) tests = unittest.TestLoader().loadTestsFromTestCase(testcase) if xmlrunner: - result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=3).run(tests) + result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2).run(tests) if not result.wasSuccessful(): failed = True else: - result = unittest.TextTestRunner(verbosity=3).run(tests) + result = unittest.TextTestRunner(verbosity=2).run(tests) if not result.wasSuccessful(): failed = True sys.exit(failed) diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index df184471993ff..b4b9f97feb7ca 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -20,6 +20,8 @@ import traceback import sys +from py4j.java_gateway import is_instance_of + from pyspark import SparkContext, RDD @@ -65,7 +67,14 @@ def call(self, milliseconds, jrdds): t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) if r: - return r._jrdd + # Here, we work around to ensure `_jrdd` is `JavaRDD` by wrapping it by `map`. + # org.apache.spark.streaming.api.python.PythonTransformFunction requires to return + # `JavaRDD`; however, this could be `JavaPairRDD` by some APIs, for example, `zip`. + # See SPARK-17756. + if is_instance_of(self.ctx._gateway, r._jrdd, "org.apache.spark.api.java.JavaRDD"): + return r._jrdd + else: + return r.map(lambda x: x)._jrdd except: self.failure = traceback.format_exc() diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index e5218d9e75e78..63ae1f30e17ca 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -34,6 +34,7 @@ class TaskContext(object): _partitionId = None _stageId = None _taskAttemptId = None + _localProperties = None def __new__(cls): """Even if users construct TaskContext instead of using get, give them the singleton.""" @@ -88,3 +89,9 @@ def taskAttemptId(self): TaskAttemptID. """ return self._taskAttemptId + + def getLocalProperty(self, key): + """ + Get a local property set upstream in the driver, or None if it is missing. + """ + return self._localProperties.get(key, None) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 9111dbbed5929..a4c5fb1db8b37 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -161,6 +161,37 @@ def gen_gs(N, step=1): self.assertEqual(k, len(vs)) self.assertEqual(list(range(k)), list(vs)) + def test_stopiteration_is_raised(self): + + def stopit(*args, **kwargs): + raise StopIteration() + + def legit_create_combiner(x): + return [x] + + def legit_merge_value(x, y): + return x.append(y) or x + + def legit_merge_combiners(x, y): + return x.extend(y) or x + + data = [(x % 2, x) for x in range(100)] + + # wrong create combiner + m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge value + m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge combiners + m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -543,6 +574,20 @@ def test_tc_on_driver(self): tc = TaskContext.get() self.assertTrue(tc is None) + def test_get_local_property(self): + """Verify that local properties set on the driver are available in TaskContext.""" + key = "testkey" + value = "testvalue" + self.sc.setLocalProperty(key, value) + try: + rdd = self.sc.parallelize(range(1), 1) + prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] + self.assertEqual(prop1, value) + prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + self.assertTrue(prop2 is None) + finally: + self.sc.setLocalProperty(key, None) + class RDDTests(ReusedPySparkTestCase): @@ -1246,6 +1291,35 @@ def test_pipe_unicode(self): result = rdd.pipe('cat').collect() self.assertEqual(data, result) + def test_stopiteration_in_user_code(self): + + def stopit(*x): + raise StopIteration() + + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + msg = "Caught StopIteration thrown from user's code; failing the task" + + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, + seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + + # these methods call the user function both in the driver and in the executor + # the exception raised is different according to where the StopIteration happens + # RuntimeError is raised if in the driver + # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, lambda *x: 1, stopit) + class ProfilerTests(PySparkTestCase): @@ -1951,7 +2025,12 @@ class SparkSubmitTests(unittest.TestCase): def setUp(self): self.programDir = tempfile.mkdtemp() - self.sparkSubmit = os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit") + tmp_dir = tempfile.gettempdir() + self.sparkSubmit = [ + os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"), + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + ] def tearDown(self): shutil.rmtree(self.programDir) @@ -2017,7 +2096,7 @@ def test_single_script(self): |sc = SparkContext() |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) """) - proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 4, 6]", out.decode('utf-8')) @@ -2033,7 +2112,7 @@ def test_script_with_local_functions(self): |sc = SparkContext() |print(sc.parallelize([1, 2, 3]).map(foo).collect()) """) - proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[3, 6, 9]", out.decode('utf-8')) @@ -2051,7 +2130,7 @@ def test_module_dependency(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, script], + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -2070,7 +2149,7 @@ def test_module_dependency_on_cluster(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master", + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() @@ -2087,8 +2166,10 @@ def test_package_dependency(self): |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, script], stdout=subprocess.PIPE) + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out.decode('utf-8')) @@ -2103,9 +2184,11 @@ def test_package_dependency_on_cluster(self): |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, "--master", - "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, "--master", "local-cluster[1,1,1024]", + script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out.decode('utf-8')) @@ -2124,7 +2207,7 @@ def test_single_script_on_cluster(self): # this will fail if you have different spark.executor.memory # in conf/spark-defaults.conf proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", script], + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -2144,7 +2227,7 @@ def test_user_configuration(self): | sc.stop() """) proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local", script], + self.sparkSubmit + ["--master", "local", script], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) out, err = proc.communicate() @@ -2303,6 +2386,10 @@ def test_py4j_exception_message(self): self.assertTrue('NullPointerException' in _exception_message(context.exception)) + def test_parsing_version_string(self): + from pyspark.util import VersionUtils + self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): @@ -2353,15 +2440,7 @@ def test_statcounter_array(self): if __name__ == "__main__": from pyspark.tests import * - if not _have_scipy: - print("NOTE: Skipping SciPy tests as it does not seem to be installed") - if not _have_numpy: - print("NOTE: Skipping NumPy tests as it does not seem to be installed") if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() - if not _have_scipy: - print("NOTE: SciPy tests were skipped as it does not seem to be installed") - if not _have_numpy: - print("NOTE: NumPy tests were skipped as it does not seem to be installed") + unittest.main(verbosity=2) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 49afc13640332..f015542c8799d 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -16,6 +16,7 @@ # limitations under the License. # +import re import sys import inspect from py4j.protocol import Py4JJavaError @@ -52,15 +53,59 @@ def _get_argspec(f): """ Get argspec of a function. Supports both Python 2 and Python 3. """ - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. if sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. argspec = inspect.getfullargspec(f) return argspec +class VersionUtils(object): + """ + Provides utility method to determine Spark versions with given input string. + """ + @staticmethod + def majorMinorVersion(sparkVersion): + """ + Given a Spark version string, return the (major version number, minor version number). + E.g., for 2.0.1-SNAPSHOT, return (2, 0). + + >>> sparkVersion = "2.4.0" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 4) + >>> sparkVersion = "2.3.0-SNAPSHOT" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 3) + + """ + m = re.search('^(\d+)\.(\d+)(\..*)?$', sparkVersion) + if m is not None: + return (int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Spark tried to parse '%s' as a Spark" % sparkVersion + + " version string, but it could not find the major and minor" + + " version numbers.") + + +def fail_on_stopiteration(f): + """ + Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' + prevents silent loss of data when 'f' is used in a for loop in Spark code + """ + def wrapper(*args, **kwargs): + try: + return f(*args, **kwargs) + except StopIteration as exc: + raise RuntimeError( + "Caught StopIteration thrown from user's code; failing the task", + exc + ) + + return wrapper + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index a1a4336b1e8de..eaaae2b14e107 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -27,6 +27,7 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry +from pyspark.java_gateway import do_server_auth from pyspark.taskcontext import TaskContext from pyspark.files import SparkFiles from pyspark.rdd import PythonEvalType @@ -34,9 +35,12 @@ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type -from pyspark.util import _get_argspec +from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle +if sys.version >= '3': + basestring = str + pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -81,7 +85,7 @@ def wrap_scalar_pandas_udf(f, return_type): def verify_result_length(*a): result = f(*a) if not hasattr(result, "__len__"): - raise TypeError("Return type of the user-defined functon should be " + raise TypeError("Return type of the user-defined function should be " "Pandas.Series, but is {}".format(type(result))) if len(result) != len(a[0]): raise RuntimeError("Result vector from pandas_udf was not the required length: " @@ -91,10 +95,12 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_grouped_map_pandas_udf(f, return_type): +def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): + assign_cols_by_pos = runner_conf.get( + "spark.sql.execution.pandas.groupedMap.assignColumnsByPosition", False) + def wrapped(key_series, value_series): import pandas as pd - argspec = _get_argspec(f) if len(argspec.args) == 1: result = f(pd.concat(value_series, axis=1)) @@ -110,9 +116,13 @@ def wrapped(key_series, value_series): "Number of columns of the returned pandas.DataFrame " "doesn't match specified schema. " "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) - arrow_return_types = (to_arrow_type(field.dataType) for field in return_type) - return [(result[result.columns[i]], arrow_type) - for i, arrow_type in enumerate(arrow_return_types)] + + # Assign result columns by schema name if user labeled with strings, else use position + if not assign_cols_by_pos and any(isinstance(name, basestring) for name in result.columns): + return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type] + else: + return [(result[result.columns[i]], to_arrow_type(field.dataType)) + for i, field in enumerate(return_type)] return wrapped @@ -128,7 +138,22 @@ def wrapped(*series): return lambda *a: (wrapped(*a), arrow_return_type) -def read_single_udf(pickleSer, infile, eval_type): +def wrap_window_agg_pandas_udf(f, return_type): + # This is similar to grouped_agg_pandas_udf, the only difference + # is that window_agg_pandas_udf needs to repeat the return value + # to match window length, where grouped_agg_pandas_udf just returns + # the scalar value. + arrow_return_type = to_arrow_type(return_type) + + def wrapped(*series): + import pandas as pd + result = f(*series) + return pd.Series([result]).repeat(len(series[0])) + + return lambda *a: (wrapped(*a), arrow_return_type) + + +def read_single_udf(pickleSer, infile, eval_type, runner_conf): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] row_func = None @@ -139,20 +164,47 @@ def read_single_udf(pickleSer, infile, eval_type): else: row_func = chain(row_func, f) + # make sure StopIteration's raised in the user code are not ignored + # when they are processed in a for loop, raise them as RuntimeError's instead + func = fail_on_stopiteration(row_func) + # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: - return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type) + return arg_offsets, wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: - return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type) + argspec = _get_argspec(row_func) # signature was lost when wrapping it + return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: - return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type) + return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) + elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: + return arg_offsets, wrap_window_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: - return arg_offsets, wrap_udf(row_func, return_type) + return arg_offsets, wrap_udf(func, return_type) else: raise ValueError("Unknown eval type: {}".format(eval_type)) def read_udfs(pickleSer, infile, eval_type): + runner_conf = {} + + if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF): + + # Load conf used for pandas_udf evaluation + num_conf = read_int(infile) + for i in range(num_conf): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + runner_conf[k] = v + + # NOTE: if timezone is set here, that implies respectSessionTimeZone is True + timezone = runner_conf.get("spark.sql.session.timeZone", None) + ser = ArrowStreamPandasSerializer(timezone) + else: + ser = BatchedSerializer(PickleSerializer(), 100) + num_udfs = read_int(infile) udfs = {} call_udf = [] @@ -167,7 +219,7 @@ def read_udfs(pickleSer, infile, eval_type): # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) udfs['f'] = udf split_offset = arg_offsets[0] + 1 arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] @@ -179,7 +231,7 @@ def read_udfs(pickleSer, infile, eval_type): # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. for i in range(num_udfs): - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) udfs['f%d' % i] = udf args = ["a[%d]" % o for o in arg_offsets] call_udf.append("f%d(%s)" % (i, ", ".join(args))) @@ -188,14 +240,6 @@ def read_udfs(pickleSer, infile, eval_type): mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) - if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): - timezone = utf8_deserializer.loads(infile) - ser = ArrowStreamPandasSerializer(timezone) - else: - ser = BatchedSerializer(PickleSerializer(), 100) - # profiling is not supported for UDF return func, None, ser, ser @@ -221,6 +265,12 @@ def main(infile, outfile): taskContext._partitionId = read_int(infile) taskContext._attemptNumber = read_int(infile) taskContext._taskAttemptId = read_long(infile) + taskContext._localProperties = dict() + for i in range(read_int(infile)): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + taskContext._localProperties[k] = v + shuffle.MemoryBytesSpilled = 0 shuffle.DiskBytesSpilled = 0 _accumulatorRegistry.clear() @@ -301,9 +351,11 @@ def process(): if __name__ == '__main__': - # Read a local port to connect to from stdin - java_port = int(sys.stdin.readline()) + # Read information about how to connect back to the JVM from the environment. + java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(("127.0.0.1", java_port)) sock_file = sock.makefile("rwb", 65536) + do_server_auth(sock_file, auth_secret) main(sock_file, sock_file) diff --git a/python/run-tests.py b/python/run-tests.py index 6b41b5ee22814..4c90926cfa350 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -22,16 +22,19 @@ from optparse import OptionParser import os import re +import shutil import subprocess import sys import tempfile from threading import Thread, Lock import time +import uuid if sys.version < '3': import Queue else: import queue as Queue from distutils.version import LooseVersion +from multiprocessing import Manager # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -50,6 +53,7 @@ def print_red(text): print('\033[31m' + text + '\033[0m') +SKIPPED_TESTS = Manager().dict() LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") FAILURE_REPORTING_LOCK = Lock() LOGGER = logging.getLogger() @@ -66,7 +70,7 @@ def print_red(text): raise Exception("Cannot find assembly build directory, please build Spark first.") -def run_individual_python_test(test_name, pyspark_python): +def run_individual_python_test(target_dir, test_name, pyspark_python): env = dict(os.environ) env.update({ 'SPARK_DIST_CLASSPATH': SPARK_DIST_CLASSPATH, @@ -75,6 +79,23 @@ def run_individual_python_test(test_name, pyspark_python): 'PYSPARK_PYTHON': which(pyspark_python), 'PYSPARK_DRIVER_PYTHON': which(pyspark_python) }) + + # Create a unique temp directory under 'target/' for each run. The TMPDIR variable is + # recognized by the tempfile module to override the default system temp directory. + tmp_dir = os.path.join(target_dir, str(uuid.uuid4())) + while os.path.isdir(tmp_dir): + tmp_dir = os.path.join(target_dir, str(uuid.uuid4())) + os.mkdir(tmp_dir) + env["TMPDIR"] = tmp_dir + + # Also override the JVM's temp directory by setting driver and executor options. + spark_args = [ + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "pyspark-shell" + ] + env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args) + LOGGER.info("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: @@ -82,6 +103,7 @@ def run_individual_python_test(test_name, pyspark_python): retcode = subprocess.Popen( [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], stderr=per_test_output, stdout=per_test_output, env=env).wait() + shutil.rmtree(tmp_dir, ignore_errors=True) except: LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python) # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if @@ -109,8 +131,34 @@ def run_individual_python_test(test_name, pyspark_python): # this code is invoked from a thread other than the main thread. os._exit(-1) else: - per_test_output.close() - LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration) + skipped_counts = 0 + try: + per_test_output.seek(0) + # Here expects skipped test output from unittest when verbosity level is + # 2 (or --verbose option is enabled). + decoded_lines = map(lambda line: line.decode(), iter(per_test_output)) + skipped_tests = list(filter( + lambda line: re.search('test_.* \(pyspark\..*\) ... skipped ', line), + decoded_lines)) + skipped_counts = len(skipped_tests) + if skipped_counts > 0: + key = (pyspark_python, test_name) + SKIPPED_TESTS[key] = skipped_tests + per_test_output.close() + except: + import traceback + print_red("\nGot an exception while trying to store " + "skipped test output:\n%s" % traceback.format_exc()) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(-1) + if skipped_counts != 0: + LOGGER.info( + "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python, test_name, + duration, skipped_counts) + else: + LOGGER.info( + "Finished test(%s): %s (%is)", pyspark_python, test_name, duration) def get_default_python_executables(): @@ -152,65 +200,17 @@ def parse_opts(): return opts -def _check_dependencies(python_exec, modules_to_test): - if "COVERAGE_PROCESS_START" in os.environ: - # Make sure if coverage is installed. - try: - subprocess_check_output( - [python_exec, "-c", "import coverage"], - stderr=open(os.devnull, 'w')) - except: - print_red("Coverage is not installed in Python executable '%s' " - "but 'COVERAGE_PROCESS_START' environment variable is set, " - "exiting." % python_exec) - sys.exit(-1) - - # If we should test 'pyspark-sql', it checks if PyArrow and Pandas are installed and - # explicitly prints out. See SPARK-23300. - if pyspark_sql in modules_to_test: - # TODO(HyukjinKwon): Relocate and deduplicate these version specifications. - minimum_pyarrow_version = '0.8.0' - minimum_pandas_version = '0.19.2' - - try: - pyarrow_version = subprocess_check_output( - [python_exec, "-c", "import pyarrow; print(pyarrow.__version__)"], - universal_newlines=True, - stderr=open(os.devnull, 'w')).strip() - if LooseVersion(pyarrow_version) >= LooseVersion(minimum_pyarrow_version): - LOGGER.info("Will test PyArrow related features against Python executable " - "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) - else: - LOGGER.warning( - "Will skip PyArrow related features against Python executable " - "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " - "%s was found." % ( - python_exec, pyspark_sql.name, minimum_pyarrow_version, pyarrow_version)) - except: - LOGGER.warning( - "Will skip PyArrow related features against Python executable " - "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " - "was not found." % (python_exec, pyspark_sql.name, minimum_pyarrow_version)) - - try: - pandas_version = subprocess_check_output( - [python_exec, "-c", "import pandas; print(pandas.__version__)"], - universal_newlines=True, - stderr=open(os.devnull, 'w')).strip() - if LooseVersion(pandas_version) >= LooseVersion(minimum_pandas_version): - LOGGER.info("Will test Pandas related features against Python executable " - "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) - else: - LOGGER.warning( - "Will skip Pandas related features against Python executable " - "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " - "%s was found." % ( - python_exec, pyspark_sql.name, minimum_pandas_version, pandas_version)) - except: - LOGGER.warning( - "Will skip Pandas related features against Python executable " - "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " - "was not found." % (python_exec, pyspark_sql.name, minimum_pandas_version)) +def _check_coverage(python_exec): + # Make sure if coverage is installed. + try: + subprocess_check_output( + [python_exec, "-c", "import coverage"], + stderr=open(os.devnull, 'w')) + except: + print_red("Coverage is not installed in Python executable '%s' " + "but 'COVERAGE_PROCESS_START' environment variable is set, " + "exiting." % python_exec) + sys.exit(-1) def main(): @@ -237,9 +237,10 @@ def main(): task_queue = Queue.PriorityQueue() for python_exec in python_execs: - # Check if the python executable has proper dependencies installed to run tests - # for given modules properly. - _check_dependencies(python_exec, modules_to_test) + # Check if the python executable has coverage installed when 'COVERAGE_PROCESS_START' + # environmental variable is set. + if "COVERAGE_PROCESS_START" in os.environ: + _check_coverage(python_exec) python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], @@ -257,6 +258,11 @@ def main(): priority = 100 task_queue.put((priority, (python_exec, test_goal))) + # Create the target directory before starting tasks to avoid races. + target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target')) + if not os.path.isdir(target_dir): + os.mkdir(target_dir) + def process_queue(task_queue): while True: try: @@ -264,7 +270,7 @@ def process_queue(task_queue): except Queue.Empty: break try: - run_individual_python_test(test_goal, python_exec) + run_individual_python_test(target_dir, test_goal, python_exec) finally: task_queue.task_done() @@ -281,6 +287,12 @@ def process_queue(task_queue): total_duration = time.time() - start_time LOGGER.info("Tests passed in %i seconds", total_duration) + for key, lines in sorted(SKIPPED_TESTS.items()): + pyspark_python, test_name = key + LOGGER.info("\nSkipped tests in %s with %s:" % (test_name, pyspark_python)) + for line in lines: + LOGGER.info(" %s" % line.rstrip()) + if __name__ == "__main__": main() diff --git a/python/setup.py b/python/setup.py index 794ceceae3008..45eb74eb87ce7 100644 --- a/python/setup.py +++ b/python/setup.py @@ -201,7 +201,7 @@ def _supports_symlinks(): 'pyspark.examples.src.main.python': ['*.py', '*/*.py']}, scripts=scripts, license='http://www.apache.org/licenses/LICENSE-2.0', - install_requires=['py4j==0.10.6'], + install_requires=['py4j==0.10.7'], setup_requires=['pypandoc'], extras_require={ 'ml': ['numpy>=1.7'], @@ -219,6 +219,7 @@ def _supports_symlinks(): 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy'] ) diff --git a/python/test_support/sql/people_array_utf16le.json b/python/test_support/sql/people_array_utf16le.json new file mode 100644 index 0000000000000..9c657fa30ac9c Binary files /dev/null and b/python/test_support/sql/people_array_utf16le.json differ diff --git a/repl/pom.xml b/repl/pom.xml index 6f4a863c48bc7..861bbd7c49654 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -102,7 +102,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded @@ -166,7 +166,7 @@
      - + scala-2.12 diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index e69441a475e9a..a44051b351e19 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -36,7 +36,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def this() = this(None, new JPrintWriter(Console.out, true)) override def createInterpreter(): Unit = { - intp = new SparkILoopInterpreter(settings, out) + intp = new SparkILoopInterpreter(settings, out, initializeSpark) } val initializationCommands: Seq[String] = Seq( @@ -73,11 +73,15 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) "import org.apache.spark.sql.functions._" ) - def initializeSpark() { - intp.beQuietDuring { - savingReplayStack { // remove the commands from session history. - initializationCommands.foreach(processLine) + def initializeSpark(): Unit = { + if (!intp.reporter.hasErrors) { + // `savingReplayStack` removes the commands from session history. + savingReplayStack { + initializationCommands.foreach(intp quietRun _) } + } else { + throw new RuntimeException(s"Scala $versionString interpreter encountered " + + "errors during initialization") } } @@ -101,16 +105,6 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) /** Available commands */ override def commands: List[LoopCommand] = standardCommands - /** - * We override `loadFiles` because we need to initialize Spark *before* the REPL - * sees any files, so that the Spark context is visible in those files. This is a bit of a - * hack, but there isn't another hook available to us at this point. - */ - override def loadFiles(settings: Settings): Unit = { - initializeSpark() - super.loadFiles(settings) - } - override def resetCommand(line: String): Unit = { super.resetCommand(line) initializeSpark() diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala index e736607a9a6b9..4e63816402a10 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala @@ -21,8 +21,22 @@ import scala.collection.mutable import scala.tools.nsc.Settings import scala.tools.nsc.interpreter._ -class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain(settings, out) { - self => +class SparkILoopInterpreter(settings: Settings, out: JPrintWriter, initializeSpark: () => Unit) + extends IMain(settings, out) { self => + + /** + * We override `initializeSynchronous` to initialize Spark *after* `intp` is properly initialized + * and *before* the REPL sees any files in the private `loadInitFiles` functions, so that + * the Spark context is visible in those files. + * + * This is a bit of a hack, but there isn't another hook available to us at this point. + * + * See the discussion in Scala community https://github.com/scala/bug/issues/10913 for detail. + */ + override def initializeSynchronous(): Unit = { + super.initializeSynchronous() + initializeSpark() + } override lazy val memberHandlers = new { val intp: self.type = self diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 4dc399827ffed..42298b06a2c86 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -22,8 +22,8 @@ import java.net.{URI, URL, URLEncoder} import java.nio.channels.Channels import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.xbean.asm5._ -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6._ +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/src/main/scala/org/apache/spark/repl/Main.scala index cc76a703bdf8f..e4ddcef9772e4 100644 --- a/repl/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/src/main/scala/org/apache/spark/repl/Main.scala @@ -44,6 +44,7 @@ object Main extends Logging { var interp: SparkILoop = _ private var hasErrors = false + private var isShellSession = false private def scalaOptionError(msg: String): Unit = { hasErrors = true @@ -53,6 +54,7 @@ object Main extends Logging { } def main(args: Array[String]) { + isShellSession = true doMain(args, new SparkILoop) } @@ -79,44 +81,50 @@ object Main extends Logging { } def createSparkSession(): SparkSession = { - val execUri = System.getenv("SPARK_EXECUTOR_URI") - conf.setIfMissing("spark.app.name", "Spark shell") - // SparkContext will detect this configuration and register it with the RpcEnv's - // file server, setting spark.repl.class.uri to the actual URI for executors to - // use. This is sort of ugly but since executors are started as part of SparkContext - // initialization in certain cases, there's an initialization order issue that prevents - // this from being set after SparkContext is instantiated. - conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) - if (execUri != null) { - conf.set("spark.executor.uri", execUri) - } - if (System.getenv("SPARK_HOME") != null) { - conf.setSparkHome(System.getenv("SPARK_HOME")) - } + try { + val execUri = System.getenv("SPARK_EXECUTOR_URI") + conf.setIfMissing("spark.app.name", "Spark shell") + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. + conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) + if (execUri != null) { + conf.set("spark.executor.uri", execUri) + } + if (System.getenv("SPARK_HOME") != null) { + conf.setSparkHome(System.getenv("SPARK_HOME")) + } - val builder = SparkSession.builder.config(conf) - if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") { - if (SparkSession.hiveClassesArePresent) { - // In the case that the property is not set at all, builder's config - // does not have this value set to 'hive' yet. The original default - // behavior is that when there are hive classes, we use hive catalog. - sparkSession = builder.enableHiveSupport().getOrCreate() - logInfo("Created Spark session with Hive support") + val builder = SparkSession.builder.config(conf) + if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") { + if (SparkSession.hiveClassesArePresent) { + // In the case that the property is not set at all, builder's config + // does not have this value set to 'hive' yet. The original default + // behavior is that when there are hive classes, we use hive catalog. + sparkSession = builder.enableHiveSupport().getOrCreate() + logInfo("Created Spark session with Hive support") + } else { + // Need to change it back to 'in-memory' if no hive classes are found + // in the case that the property is set to hive in spark-defaults.conf + builder.config(CATALOG_IMPLEMENTATION.key, "in-memory") + sparkSession = builder.getOrCreate() + logInfo("Created Spark session") + } } else { - // Need to change it back to 'in-memory' if no hive classes are found - // in the case that the property is set to hive in spark-defaults.conf - builder.config(CATALOG_IMPLEMENTATION.key, "in-memory") + // In the case that the property is set but not to 'hive', the internal + // default is 'in-memory'. So the sparkSession will use in-memory catalog. sparkSession = builder.getOrCreate() logInfo("Created Spark session") } - } else { - // In the case that the property is set but not to 'hive', the internal - // default is 'in-memory'. So the sparkSession will use in-memory catalog. - sparkSession = builder.getOrCreate() - logInfo("Created Spark session") + sparkContext = sparkSession.sparkContext + sparkSession + } catch { + case e: Exception if isShellSession => + logError("Failed to initialize Spark session.", e) + sys.exit(1) } - sparkContext = sparkSession.sparkContext - sparkSession } } diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index a62f271273465..920f0f6ebf2c8 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -47,6 +47,12 @@ test
      + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + + io.fabric8 kubernetes-client @@ -77,6 +83,12 @@ + + com.squareup.okhttp3 + okhttp + 3.8.1 + + org.mockito mockito-core @@ -84,9 +96,9 @@ - com.squareup.okhttp3 - okhttp - 3.8.1 + org.jmock + jmock-junit4 + test diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 4086970ffb256..f9a77e71ad618 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -117,6 +117,28 @@ private[spark] object Config extends Logging { .stringConf .createWithDefault("spark") + val KUBERNETES_PYSPARK_PY_FILES = + ConfigBuilder("spark.kubernetes.python.pyFiles") + .doc("The PyFiles that are distributed via client arguments") + .internal() + .stringConf + .createOptional + + val KUBERNETES_PYSPARK_MAIN_APP_RESOURCE = + ConfigBuilder("spark.kubernetes.python.mainAppResource") + .doc("The main app resource for pyspark jobs") + .internal() + .stringConf + .createOptional + + val KUBERNETES_PYSPARK_APP_ARGS = + ConfigBuilder("spark.kubernetes.python.appArgs") + .doc("The app arguments for PySpark Jobs") + .internal() + .stringConf + .createOptional + + val KUBERNETES_ALLOCATION_BATCH_SIZE = ConfigBuilder("spark.kubernetes.allocation.batch.size") .doc("Number of pods to launch at once in each round of executor allocation.") @@ -154,6 +176,41 @@ private[spark] object Config extends Logging { .checkValue(interval => interval > 0, s"Logging interval must be a positive time value.") .createWithDefaultString("1s") + val KUBERNETES_EXECUTOR_API_POLLING_INTERVAL = + ConfigBuilder("spark.kubernetes.executor.apiPollingInterval") + .doc("Interval between polls against the Kubernetes API server to inspect the " + + "state of executors.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(interval => interval > 0, s"API server polling interval must be a" + + " positive time value.") + .createWithDefaultString("30s") + + val KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL = + ConfigBuilder("spark.kubernetes.executor.eventProcessingInterval") + .doc("Interval between successive inspection of executor events sent from the" + + " Kubernetes API.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(interval => interval > 0, s"Event processing interval must be a positive" + + " time value.") + .createWithDefaultString("1s") + + val MEMORY_OVERHEAD_FACTOR = + ConfigBuilder("spark.kubernetes.memoryOverheadFactor") + .doc("This sets the Memory Overhead Factor that will allocate memory to non-JVM jobs " + + "which in the case of JVM tasks will default to 0.10 and 0.40 for non-JVM jobs") + .doubleConf + .checkValue(mem_overhead => mem_overhead >= 0 && mem_overhead < 1, + "Ensure that memory overhead is a double between 0 --> 1.0") + .createWithDefault(0.1) + + val PYSPARK_MAJOR_PYTHON_VERSION = + ConfigBuilder("spark.kubernetes.pyspark.pythonversion") + .doc("This sets the major Python version. Either 2 or 3. (Python2 or Python3)") + .stringConf + .checkValue(pv => List("2", "3").contains(pv), + "Ensure that major Python version is either Python2 or Python3") + .createWithDefault("2") + val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" @@ -162,10 +219,24 @@ private[spark] object Config extends Logging { val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." + val KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX = "spark.kubernetes.driver.secretKeyRef." + val KUBERNETES_DRIVER_VOLUMES_PREFIX = "spark.kubernetes.driver.volumes." val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." + val KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX = "spark.kubernetes.executor.secretKeyRef." + val KUBERNETES_EXECUTOR_VOLUMES_PREFIX = "spark.kubernetes.executor.volumes." + + val KUBERNETES_VOLUMES_HOSTPATH_TYPE = "hostPath" + val KUBERNETES_VOLUMES_PVC_TYPE = "persistentVolumeClaim" + val KUBERNETES_VOLUMES_EMPTYDIR_TYPE = "emptyDir" + val KUBERNETES_VOLUMES_MOUNT_PATH_KEY = "mount.path" + val KUBERNETES_VOLUMES_MOUNT_READONLY_KEY = "mount.readOnly" + val KUBERNETES_VOLUMES_OPTIONS_PATH_KEY = "options.path" + val KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY = "options.claimName" + val KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY = "options.medium" + val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit" val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 8da5f24044aad..5ecdd3a04d77b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -25,9 +25,6 @@ private[spark] object Constants { val SPARK_POD_DRIVER_ROLE = "driver" val SPARK_POD_EXECUTOR_ROLE = "executor" - // Annotations - val SPARK_APP_NAME_ANNOTATION = "spark-app-name" - // Credentials secrets val DRIVER_CREDENTIALS_SECRETS_BASE_DIR = "/mnt/secrets/spark-kubernetes-credentials" @@ -50,17 +47,14 @@ private[spark] object Constants { val DEFAULT_BLOCKMANAGER_PORT = 7079 val DRIVER_PORT_NAME = "driver-rpc-port" val BLOCK_MANAGER_PORT_NAME = "blockmanager" - val EXECUTOR_PORT_NAME = "executor" // Environment Variables - val ENV_EXECUTOR_PORT = "SPARK_EXECUTOR_PORT" val ENV_DRIVER_URL = "SPARK_DRIVER_URL" val ENV_EXECUTOR_CORES = "SPARK_EXECUTOR_CORES" val ENV_EXECUTOR_MEMORY = "SPARK_EXECUTOR_MEMORY" val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" - val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" val ENV_CLASSPATH = "SPARK_CLASSPATH" val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS" @@ -71,9 +65,14 @@ private[spark] object Constants { val SPARK_CONF_FILE_NAME = "spark.properties" val SPARK_CONF_PATH = s"$SPARK_CONF_DIR_INTERNAL/$SPARK_CONF_FILE_NAME" + // BINDINGS + val ENV_PYSPARK_PRIMARY = "PYSPARK_PRIMARY" + val ENV_PYSPARK_FILES = "PYSPARK_FILES" + val ENV_PYSPARK_ARGS = "PYSPARK_APP_ARGS" + val ENV_PYSPARK_MAJOR_PYTHON_VERSION = "PYSPARK_MAJOR_PYTHON_VERSION" + // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" val DRIVER_CONTAINER_NAME = "spark-kubernetes-driver" - val MEMORY_OVERHEAD_FACTOR = 0.10 val MEMORY_OVERHEAD_MIN_MIB = 384L } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 77b634ddfabcc..51d205fdb68d1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -16,14 +16,17 @@ */ package org.apache.spark.deploy.k8s +import scala.collection.mutable + import io.fabric8.kubernetes.api.model.{LocalObjectReference, LocalObjectReferenceBuilder, Pod} import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.{JavaMainAppResource, MainAppResource} +import org.apache.spark.deploy.k8s.submit._ import org.apache.spark.internal.config.ConfigEntry + private[spark] sealed trait KubernetesRoleSpecificConf /* @@ -54,7 +57,10 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( roleLabels: Map[String, String], roleAnnotations: Map[String, String], roleSecretNamesToMountPaths: Map[String, String], - roleEnvs: Map[String, String]) { + roleSecretEnvNamesToKeyRefs: Map[String, String], + roleEnvs: Map[String, String], + roleVolumes: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]], + sparkFiles: Seq[String]) { def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) @@ -63,10 +69,14 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( .map(str => str.split(",").toSeq) .getOrElse(Seq.empty[String]) - def sparkFiles(): Seq[String] = sparkConf - .getOption("spark.files") - .map(str => str.split(",").toSeq) - .getOrElse(Seq.empty[String]) + def pyFiles(): Option[String] = sparkConf + .get(KUBERNETES_PYSPARK_PY_FILES) + + def pySparkMainResource(): Option[String] = sparkConf + .get(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE) + + def pySparkPythonVersion(): String = sparkConf + .get(PYSPARK_MAJOR_PYTHON_VERSION) def imagePullPolicy(): String = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) @@ -101,17 +111,30 @@ private[spark] object KubernetesConf { appId: String, mainAppResource: Option[MainAppResource], mainClass: String, - appArgs: Array[String]): KubernetesConf[KubernetesDriverSpecificConf] = { + appArgs: Array[String], + maybePyFiles: Option[String]): KubernetesConf[KubernetesDriverSpecificConf] = { val sparkConfWithMainAppJar = sparkConf.clone() + val additionalFiles = mutable.ArrayBuffer.empty[String] mainAppResource.foreach { - case JavaMainAppResource(res) => - val previousJars = sparkConf - .getOption("spark.jars") - .map(_.split(",")) - .getOrElse(Array.empty) - if (!previousJars.contains(res)) { - sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) - } + case JavaMainAppResource(res) => + val previousJars = sparkConf + .getOption("spark.jars") + .map(_.split(",")) + .getOrElse(Array.empty) + if (!previousJars.contains(res)) { + sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) + } + // The function of this outer match is to account for multiple nonJVM + // bindings that will all have increased MEMORY_OVERHEAD_FACTOR to 0.4 + case nonJVM: NonJVMResource => + nonJVM match { + case PythonMainAppResource(res) => + additionalFiles += res + maybePyFiles.foreach{maybePyFiles => + additionalFiles.appendAll(maybePyFiles.split(","))} + sparkConfWithMainAppJar.set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, res) + } + sparkConfWithMainAppJar.setIfMissing(MEMORY_OVERHEAD_FACTOR, 0.4) } val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( @@ -129,8 +152,21 @@ private[spark] object KubernetesConf { sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) val driverSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX) + val driverSecretEnvNamesToKeyRefs = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX) val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) + val driverVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_DRIVER_VOLUMES_PREFIX).map(_.get) + // Also parse executor volumes in order to verify configuration + // before the driver pod is created + KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) + + val sparkFiles = sparkConf + .getOption("spark.files") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) ++ additionalFiles KubernetesConf( sparkConfWithMainAppJar, @@ -140,7 +176,10 @@ private[spark] object KubernetesConf { driverLabels, driverAnnotations, driverSecretNamesToMountPaths, - driverEnvs) + driverSecretEnvNamesToKeyRefs, + driverEnvs, + driverVolumes, + sparkFiles) } def createExecutorConf( @@ -167,9 +206,13 @@ private[spark] object KubernetesConf { executorCustomLabels val executorAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) - val executorSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + val executorMountSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val executorEnvSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX) val executorEnv = sparkConf.getExecutorEnv.toMap + val executorVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) KubernetesConf( sparkConf.clone(), @@ -178,7 +221,10 @@ private[spark] object KubernetesConf { appId, executorLabels, executorAnnotations, - executorSecrets, - executorEnv) + executorMountSecrets, + executorEnvSecrets, + executorEnv, + executorVolumes, + Seq.empty[String]) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index ee629068ad90d..66fff267545dc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,8 +16,6 @@ */ package org.apache.spark.deploy.k8s -import io.fabric8.kubernetes.api.model.LocalObjectReference - import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -52,7 +50,7 @@ private[spark] object KubernetesUtils { } } - private def resolveFileUri(uri: String): String = { + def resolveFileUri(uri: String): String = { val fileUri = Utils.resolveURI(uri) val fileScheme = Option(fileUri.getScheme).getOrElse("file") fileScheme match { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala new file mode 100644 index 0000000000000..b1762d1efe2ea --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +private[spark] sealed trait KubernetesVolumeSpecificConf + +private[spark] case class KubernetesHostPathVolumeConf( + hostPath: String) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesPVCVolumeConf( + claimName: String) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesEmptyDirVolumeConf( + medium: Option[String], + sizeLimit: Option[String]) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesVolumeSpec[T <: KubernetesVolumeSpecificConf]( + volumeName: String, + mountPath: String, + mountReadOnly: Boolean, + volumeConf: T) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala new file mode 100644 index 0000000000000..713df5fffc3a2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import java.util.NoSuchElementException + +import scala.util.{Failure, Success, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ + +private[spark] object KubernetesVolumeUtils { + /** + * Extract Spark volume configuration properties with a given name prefix. + * + * @param sparkConf Spark configuration + * @param prefix the given property name prefix + * @return a Map storing with volume name as key and spec as value + */ + def parseVolumesWithPrefix( + sparkConf: SparkConf, + prefix: String): Iterable[Try[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]]] = { + val properties = sparkConf.getAllWithPrefix(prefix).toMap + + getVolumeTypesAndNames(properties).map { case (volumeType, volumeName) => + val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_PATH_KEY" + val readOnlyKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY" + + for { + path <- properties.getTry(pathKey) + volumeConf <- parseVolumeSpecificConf(properties, volumeType, volumeName) + } yield KubernetesVolumeSpec( + volumeName = volumeName, + mountPath = path, + mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean), + volumeConf = volumeConf + ) + } + } + + /** + * Get unique pairs of volumeType and volumeName, + * assuming options are formatted in this way: + * `volumeType`.`volumeName`.`property` = `value` + * @param properties flat mapping of property names to values + * @return Set[(volumeType, volumeName)] + */ + private def getVolumeTypesAndNames( + properties: Map[String, String] + ): Set[(String, String)] = { + properties.keys.flatMap { k => + k.split('.').toList match { + case tpe :: name :: _ => Some((tpe, name)) + case _ => None + } + }.toSet + } + + private def parseVolumeSpecificConf( + options: Map[String, String], + volumeType: String, + volumeName: String): Try[KubernetesVolumeSpecificConf] = { + volumeType match { + case KUBERNETES_VOLUMES_HOSTPATH_TYPE => + val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_PATH_KEY" + for { + path <- options.getTry(pathKey) + } yield KubernetesHostPathVolumeConf(path) + + case KUBERNETES_VOLUMES_PVC_TYPE => + val claimNameKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY" + for { + claimName <- options.getTry(claimNameKey) + } yield KubernetesPVCVolumeConf(claimName) + + case KUBERNETES_VOLUMES_EMPTYDIR_TYPE => + val mediumKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY" + val sizeLimitKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY" + Success(KubernetesEmptyDirVolumeConf(options.get(mediumKey), options.get(sizeLimitKey))) + + case _ => + Failure(new RuntimeException(s"Kubernetes Volume type `$volumeType` is not supported")) + } + } + + /** + * Convenience wrapper to accumulate key lookup errors + */ + implicit private class MapOps[A, B](m: Map[A, B]) { + def getTry(key: A): Try[B] = { + m + .get(key) + .fold[Try[B]](Failure(new NoSuchElementException(key.toString)))(Success(_)) + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index 07bdccbe0479d..7e67b51de6e04 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -19,14 +19,14 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ import scala.collection.mutable -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} +import io.fabric8.kubernetes.api.model._ import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit._ import org.apache.spark.internal.config._ -import org.apache.spark.launcher.SparkLauncher private[spark] class BasicDriverFeatureStep( conf: KubernetesConf[KubernetesDriverSpecificConf]) @@ -48,7 +48,8 @@ private[spark] class BasicDriverFeatureStep( private val driverMemoryMiB = conf.get(DRIVER_MEMORY) private val memoryOverheadMiB = conf .get(DRIVER_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) + .getOrElse(math.max((conf.get(MEMORY_OVERHEAD_FACTOR) * driverMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB override def configurePod(pod: SparkPod): SparkPod = { @@ -88,13 +89,6 @@ private[spark] class BasicDriverFeatureStep( .addToRequests("memory", driverMemoryQuantity) .addToLimits("memory", driverMemoryQuantity) .endResources() - .addToArgs("driver") - .addToArgs("--properties-file", SPARK_CONF_PATH) - .addToArgs("--class", conf.roleSpecificConf.mainClass) - // The user application jar is merged into the spark.jars list and managed through that - // property, so there is no need to reference it explicitly here. - .addToArgs(SparkLauncher.NO_RESOURCE) - .addToArgs(conf.roleSpecificConf.appArgs: _*) .build() val driverPod = new PodBuilder(pod.pod) @@ -109,6 +103,7 @@ private[spark] class BasicDriverFeatureStep( .addToImagePullSecrets(conf.imagePullSecrets(): _*) .endSpec() .build() + SparkPod(driverPod, driverContainer) } @@ -122,7 +117,7 @@ private[spark] class BasicDriverFeatureStep( val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath( conf.sparkJars()) val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath( - conf.sparkFiles()) + conf.sparkFiles) if (resolvedSparkJars.nonEmpty) { additionalProps.put("spark.jars", resolvedSparkJars.mkString(",")) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index d22097587aafe..abaeff0313a79 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -18,10 +18,10 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} +import io.fabric8.kubernetes.api.model._ import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} @@ -54,7 +54,8 @@ private[spark] class BasicExecutorFeatureStep( private val memoryOverheadMiB = kubernetesConf .get(EXECUTOR_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, + .getOrElse(math.max( + (kubernetesConf.get(MEMORY_OVERHEAD_FACTOR) * executorMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB @@ -89,7 +90,9 @@ private[spark] class BasicExecutorFeatureStep( val executorExtraJavaOptionsEnv = kubernetesConf .get(EXECUTOR_JAVA_OPTIONS) .map { opts => - val delimitedOpts = Utils.splitCommandString(opts) + val subsOpts = Utils.substituteAppNExecIds(opts, kubernetesConf.appId, + kubernetesConf.roleSpecificConf.executorId) + val delimitedOpts = Utils.splitCommandString(subsOpts) delimitedOpts.zipWithIndex.map { case (opt, index) => new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() @@ -170,6 +173,7 @@ private[spark] class BasicExecutorFeatureStep( .addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*) .endSpec() .build() + SparkPod(executorPod, containerWithLimitCores) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala new file mode 100644 index 0000000000000..03ff7d48420ff --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class EnvSecretsFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val addedEnvSecrets = kubernetesConf + .roleSecretEnvNamesToKeyRefs + .map{ case (envName, keyRef) => + // Keyref parts + val keyRefParts = keyRef.split(":") + require(keyRefParts.size == 2, "SecretKeyRef must be in the form name:key.") + val name = keyRefParts(0) + val key = keyRefParts(1) + new EnvVarBuilder() + .withName(envName) + .withNewValueFrom() + .withNewSecretKeyRef() + .withKey(key) + .withName(name) + .endSecretKeyRef() + .endValueFrom() + .build() + } + + val containerWithEnvVars = new ContainerBuilder(pod.container) + .addAllToEnv(addedEnvSecrets.toSeq.asJava) + .build() + SparkPod(pod.pod, containerWithEnvVars) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala new file mode 100644 index 0000000000000..70b307303d149 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.nio.file.Paths +import java.util.UUID + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class LocalDirsFeatureStep( + conf: KubernetesConf[_ <: KubernetesRoleSpecificConf], + defaultLocalDir: String = s"/var/data/spark-${UUID.randomUUID}") + extends KubernetesFeatureConfigStep { + + // Cannot use Utils.getConfiguredLocalDirs because that will default to the Java system + // property - we want to instead default to mounting an emptydir volume that doesn't already + // exist in the image. + // We could make utils.getConfiguredLocalDirs opinionated about Kubernetes, as it is already + // a bit opinionated about YARN and Mesos. + private val resolvedLocalDirs = Option(conf.sparkConf.getenv("SPARK_LOCAL_DIRS")) + .orElse(conf.getOption("spark.local.dir")) + .getOrElse(defaultLocalDir) + .split(",") + + override def configurePod(pod: SparkPod): SparkPod = { + val localDirVolumes = resolvedLocalDirs + .zipWithIndex + .map { case (localDir, index) => + new VolumeBuilder() + .withName(s"spark-local-dir-${index + 1}") + .withNewEmptyDir() + .endEmptyDir() + .build() + } + val localDirVolumeMounts = localDirVolumes + .zip(resolvedLocalDirs) + .map { case (localDirVolume, localDirPath) => + new VolumeMountBuilder() + .withName(localDirVolume.getName) + .withMountPath(localDirPath) + .build() + } + val podWithLocalDirVolumes = new PodBuilder(pod.pod) + .editSpec() + .addToVolumes(localDirVolumes: _*) + .endSpec() + .build() + val containerWithLocalDirVolumeMounts = new ContainerBuilder(pod.container) + .addNewEnv() + .withName("SPARK_LOCAL_DIRS") + .withValue(resolvedLocalDirs.mkString(",")) + .endEnv() + .addToVolumeMounts(localDirVolumeMounts: _*) + .build() + SparkPod(podWithLocalDirVolumes, containerWithLocalDirVolumeMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala new file mode 100644 index 0000000000000..bb0e2b3128efd --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model._ + +import org.apache.spark.deploy.k8s._ + +private[spark] class MountVolumesFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + + override def configurePod(pod: SparkPod): SparkPod = { + val (volumeMounts, volumes) = constructVolumes(kubernetesConf.roleVolumes).unzip + + val podWithVolumes = new PodBuilder(pod.pod) + .editSpec() + .addToVolumes(volumes.toSeq: _*) + .endSpec() + .build() + + val containerWithVolumeMounts = new ContainerBuilder(pod.container) + .addToVolumeMounts(volumeMounts.toSeq: _*) + .build() + + SparkPod(podWithVolumes, containerWithVolumeMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty + + private def constructVolumes( + volumeSpecs: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]] + ): Iterable[(VolumeMount, Volume)] = { + volumeSpecs.map { spec => + val volumeMount = new VolumeMountBuilder() + .withMountPath(spec.mountPath) + .withReadOnly(spec.mountReadOnly) + .withName(spec.volumeName) + .build() + + val volumeBuilder = spec.volumeConf match { + case KubernetesHostPathVolumeConf(hostPath) => + new VolumeBuilder() + .withHostPath(new HostPathVolumeSource(hostPath)) + + case KubernetesPVCVolumeConf(claimName) => + new VolumeBuilder() + .withPersistentVolumeClaim( + new PersistentVolumeClaimVolumeSource(claimName, spec.mountReadOnly)) + + case KubernetesEmptyDirVolumeConf(medium, sizeLimit) => + new VolumeBuilder() + .withEmptyDir( + new EmptyDirVolumeSource(medium.getOrElse(""), + new Quantity(sizeLimit.orNull))) + } + + val volume = volumeBuilder.withName(spec.volumeName).build() + + (volumeMount, volume) + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala new file mode 100644 index 0000000000000..f52ec9fdc677e --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Constants.SPARK_CONF_PATH +import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep +import org.apache.spark.launcher.SparkLauncher + +private[spark] class JavaDriverFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val withDriverArgs = new ContainerBuilder(pod.container) + .addToArgs("driver") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", kubernetesConf.roleSpecificConf.mainClass) + // The user application jar is merged into the spark.jars list and managed through that + // property, so there is no need to reference it explicitly here. + .addToArgs(SparkLauncher.NO_RESOURCE) + .addToArgs(kubernetesConf.roleSpecificConf.appArgs: _*) + .build() + SparkPod(pod.pod, withDriverArgs) + } + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala new file mode 100644 index 0000000000000..c20bcac1f8987 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep + +private[spark] class PythonDriverFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val roleConf = kubernetesConf.roleSpecificConf + require(roleConf.mainAppResource.isDefined, "PySpark Main Resource must be defined") + val maybePythonArgs = Option(roleConf.appArgs).filter(_.nonEmpty).map( + pyArgs => + new EnvVarBuilder() + .withName(ENV_PYSPARK_ARGS) + .withValue(pyArgs.mkString(",")) + .build()) + val maybePythonFiles = kubernetesConf.pyFiles().map( + // Dilineation by ":" is to append the PySpark Files to the PYTHONPATH + // of the respective PySpark pod + pyFiles => + new EnvVarBuilder() + .withName(ENV_PYSPARK_FILES) + .withValue(KubernetesUtils.resolveFileUrisAndPath(pyFiles.split(",")) + .mkString(":")) + .build()) + val envSeq = + Seq(new EnvVarBuilder() + .withName(ENV_PYSPARK_PRIMARY) + .withValue(KubernetesUtils.resolveFileUri(kubernetesConf.pySparkMainResource().get)) + .build(), + new EnvVarBuilder() + .withName(ENV_PYSPARK_MAJOR_PYTHON_VERSION) + .withValue(kubernetesConf.pySparkPythonVersion()) + .build()) + val pythonEnvs = envSeq ++ + maybePythonArgs.toSeq ++ + maybePythonFiles.toSeq + + val withPythonPrimaryContainer = new ContainerBuilder(pod.container) + .addAllToEnv(pythonEnvs.asJava) + .addToArgs("driver-py") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", roleConf.mainClass) + .build() + + SparkPod(pod.pod, withPythonPrimaryContainer) + } + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index a97f5650fb869..eaff47205dbbc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -39,11 +39,13 @@ import org.apache.spark.util.Utils * @param mainAppResource the main application resource if any * @param mainClass the main class of the application to run * @param driverArgs arguments to the driver + * @param maybePyFiles additional Python files via --py-files */ private[spark] case class ClientArguments( mainAppResource: Option[MainAppResource], mainClass: String, - driverArgs: Array[String]) + driverArgs: Array[String], + maybePyFiles: Option[String]) private[spark] object ClientArguments { @@ -51,10 +53,15 @@ private[spark] object ClientArguments { var mainAppResource: Option[MainAppResource] = None var mainClass: Option[String] = None val driverArgs = mutable.ArrayBuffer.empty[String] + var maybePyFiles : Option[String] = None args.sliding(2, 2).toList.foreach { case Array("--primary-java-resource", primaryJavaResource: String) => mainAppResource = Some(JavaMainAppResource(primaryJavaResource)) + case Array("--primary-py-file", primaryPythonResource: String) => + mainAppResource = Some(PythonMainAppResource(primaryPythonResource)) + case Array("--other-py-files", pyFiles: String) => + maybePyFiles = Some(pyFiles) case Array("--main-class", clazz: String) => mainClass = Some(clazz) case Array("--arg", arg: String) => @@ -69,7 +76,8 @@ private[spark] object ClientArguments { ClientArguments( mainAppResource, mainClass.get, - driverArgs.toArray) + driverArgs.toArray, + maybePyFiles) } } @@ -206,6 +214,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val kubernetesResourceNamePrefix = { s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") } + sparkConf.set(KUBERNETES_PYSPARK_PY_FILES, clientArguments.maybePyFiles.getOrElse("")) val kubernetesConf = KubernetesConf.createDriverConf( sparkConf, appName, @@ -213,7 +222,8 @@ private[spark] class KubernetesClientApplication extends SparkApplication { kubernetesAppId, clientArguments.mainAppResource, clientArguments.mainClass, - clientArguments.driverArgs) + clientArguments.driverArgs, + clientArguments.maybePyFiles) val builder = new KubernetesDriverBuilder val namespace = kubernetesConf.namespace() // The master URL has been checked for validity already in SparkSubmit. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index c7579ed8cb689..7208e3d377593 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,7 +17,8 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features._ +import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} private[spark] class KubernetesDriverBuilder( provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = @@ -29,17 +30,52 @@ private[spark] class KubernetesDriverBuilder( new DriverServiceFeatureStep(_), provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => MountSecretsFeatureStep) = - new MountSecretsFeatureStep(_)) { + new MountSecretsFeatureStep(_), + provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => EnvSecretsFeatureStep) = + new EnvSecretsFeatureStep(_), + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => LocalDirsFeatureStep = + new LocalDirsFeatureStep(_), + provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountVolumesFeatureStep) = + new MountVolumesFeatureStep(_), + provideJavaStep: ( + KubernetesConf[KubernetesDriverSpecificConf] + => JavaDriverFeatureStep) = + new JavaDriverFeatureStep(_), + providePythonStep: ( + KubernetesConf[KubernetesDriverSpecificConf] + => PythonDriverFeatureStep) = + new PythonDriverFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { val baseFeatures = Seq( provideBasicStep(kubernetesConf), provideCredentialsStep(kubernetesConf), - provideServiceStep(kubernetesConf)) - val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) - } else baseFeatures + provideServiceStep(kubernetesConf), + provideLocalDirsStep(kubernetesConf)) + + val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Seq(provideSecretsStep(kubernetesConf)) + } else Nil + val envSecretFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Seq(provideEnvSecretsStep(kubernetesConf)) + } else Nil + val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) { + Seq(provideVolumesStep(kubernetesConf)) + } else Nil + + val bindingsStep = kubernetesConf.roleSpecificConf.mainAppResource.map { + case JavaMainAppResource(_) => + provideJavaStep(kubernetesConf) + case PythonMainAppResource(_) => + providePythonStep(kubernetesConf)} + .getOrElse(provideJavaStep(kubernetesConf)) + + val allFeatures = (baseFeatures :+ bindingsStep) ++ + secretFeature ++ envSecretFeature ++ volumesFeature var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) for (feature <- allFeatures) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala index cca9f4627a1f6..cbe081ae35683 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala @@ -18,4 +18,9 @@ package org.apache.spark.deploy.k8s.submit private[spark] sealed trait MainAppResource +private[spark] sealed trait NonJVMResource + private[spark] case class JavaMainAppResource(primaryResource: String) extends MainAppResource + +private[spark] case class PythonMainAppResource(primaryResource: String) + extends MainAppResource with NonJVMResource diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala new file mode 100644 index 0000000000000..83daddf714489 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +sealed trait ExecutorPodState { + def pod: Pod +} + +case class PodRunning(pod: Pod) extends ExecutorPodState + +case class PodPending(pod: Pod) extends ExecutorPodState + +sealed trait FinalPodState extends ExecutorPodState + +case class PodSucceeded(pod: Pod) extends FinalPodState + +case class PodFailed(pod: Pod) extends FinalPodState + +case class PodDeleted(pod: Pod) extends FinalPodState + +case class PodUnknown(pod: Pod) extends ExecutorPodState diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala new file mode 100644 index 0000000000000..5a143ad3600fd --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} + +import io.fabric8.kubernetes.api.model.PodBuilder +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.mutable + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesConf +import org.apache.spark.internal.Logging +import org.apache.spark.util.{Clock, Utils} + +private[spark] class ExecutorPodsAllocator( + conf: SparkConf, + executorBuilder: KubernetesExecutorBuilder, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + clock: Clock) extends Logging { + + private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) + + private val totalExpectedExecutors = new AtomicInteger(0) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + + private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + + private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000) + + private val kubernetesDriverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(throw new SparkException("Must specify the driver pod name")) + + private val driverPod = kubernetesClient.pods() + .withName(kubernetesDriverPodName) + .get() + + // Executor IDs that have been requested from Kubernetes but have not been detected in any + // snapshot yet. Mapped to the timestamp when they were created. + private val newlyCreatedExecutors = mutable.Map.empty[Long, Long] + + def start(applicationId: String): Unit = { + snapshotsStore.addSubscriber(podAllocationDelay) { + onNewSnapshots(applicationId, _) + } + } + + def setTotalExpectedExecutors(total: Int): Unit = totalExpectedExecutors.set(total) + + private def onNewSnapshots(applicationId: String, snapshots: Seq[ExecutorPodsSnapshot]): Unit = { + newlyCreatedExecutors --= snapshots.flatMap(_.executorPods.keys) + // For all executors we've created against the API but have not seen in a snapshot + // yet - check the current time. If the current time has exceeded some threshold, + // assume that the pod was either never created (the API server never properly + // handled the creation request), or the API server created the pod but we missed + // both the creation and deletion events. In either case, delete the missing pod + // if possible, and mark such a pod to be rescheduled below. + newlyCreatedExecutors.foreach { case (execId, timeCreated) => + val currentTime = clock.getTimeMillis() + if (currentTime - timeCreated > podCreationTimeout) { + logWarning(s"Executor with id $execId was not detected in the Kubernetes" + + s" cluster after $podCreationTimeout milliseconds despite the fact that a" + + " previous allocation attempt tried to create it. The executor may have been" + + " deleted but the application missed the deletion event.") + Utils.tryLogNonFatalError { + kubernetesClient + .pods() + .withLabel(SPARK_EXECUTOR_ID_LABEL, execId.toString) + .delete() + } + newlyCreatedExecutors -= execId + } else { + logDebug(s"Executor with id $execId was not found in the Kubernetes cluster since it" + + s" was created ${currentTime - timeCreated} milliseconds ago.") + } + } + + if (snapshots.nonEmpty) { + // Only need to examine the cluster as of the latest snapshot, the "current" state, to see if + // we need to allocate more executors or not. + val latestSnapshot = snapshots.last + val currentRunningExecutors = latestSnapshot.executorPods.values.count { + case PodRunning(_) => true + case _ => false + } + val currentPendingExecutors = latestSnapshot.executorPods.values.count { + case PodPending(_) => true + case _ => false + } + val currentTotalExpectedExecutors = totalExpectedExecutors.get + logDebug(s"Currently have $currentRunningExecutors running executors and" + + s" $currentPendingExecutors pending executors. $newlyCreatedExecutors executors" + + s" have been requested but are pending appearance in the cluster.") + if (newlyCreatedExecutors.isEmpty + && currentPendingExecutors == 0 + && currentRunningExecutors < currentTotalExpectedExecutors) { + val numExecutorsToAllocate = math.min( + currentTotalExpectedExecutors - currentRunningExecutors, podAllocationSize) + logInfo(s"Going to request $numExecutorsToAllocate executors from Kubernetes.") + for ( _ <- 0 until numExecutorsToAllocate) { + val newExecutorId = EXECUTOR_ID_COUNTER.incrementAndGet() + val executorConf = KubernetesConf.createExecutorConf( + conf, + newExecutorId.toString, + applicationId, + driverPod) + val executorPod = executorBuilder.buildFromFeatures(executorConf) + val podWithAttachedContainer = new PodBuilder(executorPod.pod) + .editOrNewSpec() + .addToContainers(executorPod.container) + .endSpec() + .build() + kubernetesClient.pods().create(podWithAttachedContainer) + newlyCreatedExecutors(newExecutorId) = clock.getTimeMillis() + logDebug(s"Requested executor with id $newExecutorId from Kubernetes.") + } + } else if (currentRunningExecutors >= currentTotalExpectedExecutors) { + // TODO handle edge cases if we end up with more running executors than expected. + logDebug("Current number of running executors is equal to the number of requested" + + " executors. Not scaling up further.") + } else if (newlyCreatedExecutors.nonEmpty || currentPendingExecutors != 0) { + logDebug(s"Still waiting for ${newlyCreatedExecutors.size + currentPendingExecutors}" + + s" executors to begin running before requesting for more executors. # of executors in" + + s" pending status in the cluster: $currentPendingExecutors. # of executors that we have" + + s" created but we have not observed as being present in the cluster yet:" + + s" ${newlyCreatedExecutors.size}.") + } + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala new file mode 100644 index 0000000000000..b28d93990313e --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import com.google.common.cache.Cache +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ExecutorExited +import org.apache.spark.util.Utils + +private[spark] class ExecutorPodsLifecycleManager( + conf: SparkConf, + executorBuilder: KubernetesExecutorBuilder, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + // Use a best-effort to track which executors have been removed already. It's not generally + // job-breaking if we remove executors more than once but it's ideal if we make an attempt + // to avoid doing so. Expire cache entries so that this data structure doesn't grow beyond + // bounds. + removedExecutorsCache: Cache[java.lang.Long, java.lang.Long]) extends Logging { + + import ExecutorPodsLifecycleManager._ + + private val eventProcessingInterval = conf.get(KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL) + + def start(schedulerBackend: KubernetesClusterSchedulerBackend): Unit = { + snapshotsStore.addSubscriber(eventProcessingInterval) { + onNewSnapshots(schedulerBackend, _) + } + } + + private def onNewSnapshots( + schedulerBackend: KubernetesClusterSchedulerBackend, + snapshots: Seq[ExecutorPodsSnapshot]): Unit = { + val execIdsRemovedInThisRound = mutable.HashSet.empty[Long] + snapshots.foreach { snapshot => + snapshot.executorPods.foreach { case (execId, state) => + state match { + case deleted@PodDeleted(_) => + logDebug(s"Snapshot reported deleted executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}") + removeExecutorFromSpark(schedulerBackend, deleted, execId) + execIdsRemovedInThisRound += execId + case failed@PodFailed(_) => + logDebug(s"Snapshot reported failed executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}") + onFinalNonDeletedState(failed, execId, schedulerBackend, execIdsRemovedInThisRound) + case succeeded@PodSucceeded(_) => + logDebug(s"Snapshot reported succeeded executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}. Note that succeeded executors are" + + s" unusual unless Spark specifically informed the executor to exit.") + onFinalNonDeletedState(succeeded, execId, schedulerBackend, execIdsRemovedInThisRound) + case _ => + } + } + } + + // Reconcile the case where Spark claims to know about an executor but the corresponding pod + // is missing from the cluster. This would occur if we miss a deletion event and the pod + // transitions immediately from running io absent. We only need to check against the latest + // snapshot for this, and we don't do this for executors in the deleted executors cache or + // that we just removed in this round. + if (snapshots.nonEmpty) { + val latestSnapshot = snapshots.last + (schedulerBackend.getExecutorIds().map(_.toLong).toSet + -- latestSnapshot.executorPods.keySet + -- execIdsRemovedInThisRound).foreach { missingExecutorId => + if (removedExecutorsCache.getIfPresent(missingExecutorId) == null) { + val exitReasonMessage = s"The executor with ID $missingExecutorId was not found in the" + + s" cluster but we didn't get a reason why. Marking the executor as failed. The" + + s" executor may have been deleted but the driver missed the deletion event." + logDebug(exitReasonMessage) + val exitReason = ExecutorExited( + UNKNOWN_EXIT_CODE, + exitCausedByApp = false, + exitReasonMessage) + schedulerBackend.doRemoveExecutor(missingExecutorId.toString, exitReason) + execIdsRemovedInThisRound += missingExecutorId + } + } + } + logDebug(s"Removed executors with ids ${execIdsRemovedInThisRound.mkString(",")}" + + s" from Spark that were either found to be deleted or non-existent in the cluster.") + } + + private def onFinalNonDeletedState( + podState: FinalPodState, + execId: Long, + schedulerBackend: KubernetesClusterSchedulerBackend, + execIdsRemovedInRound: mutable.Set[Long]): Unit = { + removeExecutorFromK8s(podState.pod) + removeExecutorFromSpark(schedulerBackend, podState, execId) + execIdsRemovedInRound += execId + } + + private def removeExecutorFromK8s(updatedPod: Pod): Unit = { + // If deletion failed on a previous try, we can try again if resync informs us the pod + // is still around. + // Delete as best attempt - duplicate deletes will throw an exception but the end state + // of getting rid of the pod is what matters. + Utils.tryLogNonFatalError { + kubernetesClient + .pods() + .withName(updatedPod.getMetadata.getName) + .delete() + } + } + + private def removeExecutorFromSpark( + schedulerBackend: KubernetesClusterSchedulerBackend, + podState: FinalPodState, + execId: Long): Unit = { + if (removedExecutorsCache.getIfPresent(execId) == null) { + removedExecutorsCache.put(execId, execId) + val exitReason = findExitReason(podState, execId) + schedulerBackend.doRemoveExecutor(execId.toString, exitReason) + } + } + + private def findExitReason(podState: FinalPodState, execId: Long): ExecutorExited = { + val exitCode = findExitCode(podState) + val (exitCausedByApp, exitMessage) = podState match { + case PodDeleted(_) => + (false, s"The executor with id $execId was deleted by a user or the framework.") + case _ => + val msg = exitReasonMessage(podState, execId, exitCode) + (true, msg) + } + ExecutorExited(exitCode, exitCausedByApp, exitMessage) + } + + private def exitReasonMessage(podState: FinalPodState, execId: Long, exitCode: Int) = { + val pod = podState.pod + s""" + |The executor with id $execId exited with exit code $exitCode. + |The API gave the following brief reason: ${pod.getStatus.getReason} + |The API gave the following message: ${pod.getStatus.getMessage} + |The API gave the following container statuses: + | + |${pod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + """.stripMargin + } + + private def findExitCode(podState: FinalPodState): Int = { + podState.pod.getStatus.getContainerStatuses.asScala.find { containerStatus => + containerStatus.getState.getTerminated != null + }.map { terminatedContainer => + terminatedContainer.getState.getTerminated.getExitCode.toInt + }.getOrElse(UNKNOWN_EXIT_CODE) + } +} + +private object ExecutorPodsLifecycleManager { + val UNKNOWN_EXIT_CODE = -1 +} + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala new file mode 100644 index 0000000000000..e77e604d00e0f --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.{Future, ScheduledExecutorService, TimeUnit} + +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.ThreadUtils + +private[spark] class ExecutorPodsPollingSnapshotSource( + conf: SparkConf, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + pollingExecutor: ScheduledExecutorService) extends Logging { + + private val pollingInterval = conf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL) + + private var pollingFuture: Future[_] = _ + + def start(applicationId: String): Unit = { + require(pollingFuture == null, "Cannot start polling more than once.") + logDebug(s"Starting to check for executor pod state every $pollingInterval ms.") + pollingFuture = pollingExecutor.scheduleWithFixedDelay( + new PollRunnable(applicationId), pollingInterval, pollingInterval, TimeUnit.MILLISECONDS) + } + + def stop(): Unit = { + if (pollingFuture != null) { + pollingFuture.cancel(true) + pollingFuture = null + } + ThreadUtils.shutdown(pollingExecutor) + } + + private class PollRunnable(applicationId: String) extends Runnable { + override def run(): Unit = { + logDebug(s"Resynchronizing full executor pod state from Kubernetes.") + snapshotsStore.replaceSnapshot(kubernetesClient + .pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .list() + .getItems + .asScala) + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala new file mode 100644 index 0000000000000..26be918043412 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging + +/** + * An immutable view of the current executor pods that are running in the cluster. + */ +private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, ExecutorPodState]) { + + import ExecutorPodsSnapshot._ + + def withUpdate(updatedPod: Pod): ExecutorPodsSnapshot = { + val newExecutorPods = executorPods ++ toStatesByExecutorId(Seq(updatedPod)) + new ExecutorPodsSnapshot(newExecutorPods) + } +} + +object ExecutorPodsSnapshot extends Logging { + + def apply(executorPods: Seq[Pod]): ExecutorPodsSnapshot = { + ExecutorPodsSnapshot(toStatesByExecutorId(executorPods)) + } + + def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, ExecutorPodState]) + + private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, ExecutorPodState] = { + executorPods.map { pod => + (pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL).toLong, toState(pod)) + }.toMap + } + + private def toState(pod: Pod): ExecutorPodState = { + if (isDeleted(pod)) { + PodDeleted(pod) + } else { + val phase = pod.getStatus.getPhase.toLowerCase + phase match { + case "pending" => + PodPending(pod) + case "running" => + PodRunning(pod) + case "failed" => + PodFailed(pod) + case "succeeded" => + PodSucceeded(pod) + case _ => + logWarning(s"Received unknown phase $phase for executor pod with name" + + s" ${pod.getMetadata.getName} in namespace ${pod.getMetadata.getNamespace}") + PodUnknown(pod) + } + } + } + + private def isDeleted(pod: Pod): Boolean = pod.getMetadata.getDeletionTimestamp != null +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala new file mode 100644 index 0000000000000..dd264332cf9e8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +private[spark] trait ExecutorPodsSnapshotsStore { + + def addSubscriber + (processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit) + + def stop(): Unit + + def updatePod(updatedPod: Pod): Unit + + def replaceSnapshot(newSnapshot: Seq[Pod]): Unit +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala new file mode 100644 index 0000000000000..5583b4617eeb2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent._ + +import io.fabric8.kubernetes.api.model.Pod +import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Controls the propagation of the Spark application's executor pods state to subscribers that + * react to that state. + *
      + * Roughly follows a producer-consumer model. Producers report states of executor pods, and these + * states are then published to consumers that can perform any actions in response to these states. + *
      + * Producers push updates in one of two ways. An incremental update sent by updatePod() represents + * a known new state of a single executor pod. A full sync sent by replaceSnapshot() indicates that + * the passed pods are all of the most up to date states of all executor pods for the application. + * The combination of the states of all executor pods for the application is collectively known as + * a snapshot. The store keeps track of the most up to date snapshot, and applies updates to that + * most recent snapshot - either by incrementally updating the snapshot with a single new pod state, + * or by replacing the snapshot entirely on a full sync. + *
      + * Consumers, or subscribers, register that they want to be informed about all snapshots of the + * executor pods. Every time the store replaces its most up to date snapshot from either an + * incremental update or a full sync, the most recent snapshot after the update is posted to the + * subscriber's buffer. Subscribers receive blocks of snapshots produced by the producers in + * time-windowed chunks. Each subscriber can choose to receive their snapshot chunks at different + * time intervals. + */ +private[spark] class ExecutorPodsSnapshotsStoreImpl(subscribersExecutor: ScheduledExecutorService) + extends ExecutorPodsSnapshotsStore { + + private val SNAPSHOT_LOCK = new Object() + + private val subscribers = mutable.Buffer.empty[SnapshotsSubscriber] + private val pollingTasks = mutable.Buffer.empty[Future[_]] + + @GuardedBy("SNAPSHOT_LOCK") + private var currentSnapshot = ExecutorPodsSnapshot() + + override def addSubscriber( + processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = { + val newSubscriber = SnapshotsSubscriber( + new LinkedBlockingQueue[ExecutorPodsSnapshot](), onNewSnapshots) + SNAPSHOT_LOCK.synchronized { + newSubscriber.snapshotsBuffer.add(currentSnapshot) + } + subscribers += newSubscriber + pollingTasks += subscribersExecutor.scheduleWithFixedDelay( + toRunnable(() => callSubscriber(newSubscriber)), + 0L, + processBatchIntervalMillis, + TimeUnit.MILLISECONDS) + } + + override def stop(): Unit = { + pollingTasks.foreach(_.cancel(true)) + ThreadUtils.shutdown(subscribersExecutor) + } + + override def updatePod(updatedPod: Pod): Unit = SNAPSHOT_LOCK.synchronized { + currentSnapshot = currentSnapshot.withUpdate(updatedPod) + addCurrentSnapshotToSubscribers() + } + + override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = SNAPSHOT_LOCK.synchronized { + currentSnapshot = ExecutorPodsSnapshot(newSnapshot) + addCurrentSnapshotToSubscribers() + } + + private def addCurrentSnapshotToSubscribers(): Unit = { + subscribers.foreach { subscriber => + subscriber.snapshotsBuffer.add(currentSnapshot) + } + } + + private def callSubscriber(subscriber: SnapshotsSubscriber): Unit = { + Utils.tryLogNonFatalError { + val currentSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot].asJava + subscriber.snapshotsBuffer.drainTo(currentSnapshots) + subscriber.onNewSnapshots(currentSnapshots.asScala) + } + } + + private def toRunnable[T](runnable: () => Unit): Runnable = new Runnable { + override def run(): Unit = runnable() + } + + private case class SnapshotsSubscriber( + snapshotsBuffer: BlockingQueue[ExecutorPodsSnapshot], + onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit) +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala new file mode 100644 index 0000000000000..a6749a644e00c --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.io.Closeable + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +private[spark] class ExecutorPodsWatchSnapshotSource( + snapshotsStore: ExecutorPodsSnapshotsStore, + kubernetesClient: KubernetesClient) extends Logging { + + private var watchConnection: Closeable = _ + + def start(applicationId: String): Unit = { + require(watchConnection == null, "Cannot start the watcher twice.") + logDebug(s"Starting watch for pods with labels $SPARK_APP_ID_LABEL=$applicationId," + + s" $SPARK_ROLE_LABEL=$SPARK_POD_EXECUTOR_ROLE.") + watchConnection = kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .watch(new ExecutorPodsWatcher()) + } + + def stop(): Unit = { + if (watchConnection != null) { + Utils.tryLogNonFatalError { + watchConnection.close() + } + watchConnection = null + } + } + + private class ExecutorPodsWatcher extends Watcher[Pod] { + override def eventReceived(action: Action, pod: Pod): Unit = { + val podName = pod.getMetadata.getName + logDebug(s"Received executor pod update for pod named $podName, action $action") + snapshotsStore.updatePod(pod) + } + + override def onClose(e: KubernetesClientException): Unit = { + logWarning("Kubernetes client has been closed (this is expected if the application is" + + " shutting down.)", e) + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index 0ea80dfbc0d97..de2a52bc7a0b8 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -17,7 +17,9 @@ package org.apache.spark.scheduler.cluster.k8s import java.io.File +import java.util.concurrent.TimeUnit +import com.google.common.cache.CacheBuilder import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} @@ -26,7 +28,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{SystemClock, ThreadUtils} private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging { @@ -46,8 +48,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { - val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, Some(sc.conf.get(KUBERNETES_NAMESPACE)), @@ -56,17 +56,45 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val allocatorExecutor = ThreadUtils - .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( "kubernetes-executor-requests") + + val subscribersExecutor = ThreadUtils + .newDaemonThreadPoolScheduledExecutor( + "kubernetes-executor-snapshots-subscribers", 2) + val snapshotsStore = new ExecutorPodsSnapshotsStoreImpl(subscribersExecutor) + val removedExecutorsCache = CacheBuilder.newBuilder() + .expireAfterWrite(3, TimeUnit.MINUTES) + .build[java.lang.Long, java.lang.Long]() + val executorPodsLifecycleEventHandler = new ExecutorPodsLifecycleManager( + sc.conf, + new KubernetesExecutorBuilder(), + kubernetesClient, + snapshotsStore, + removedExecutorsCache) + + val executorPodsAllocator = new ExecutorPodsAllocator( + sc.conf, new KubernetesExecutorBuilder(), kubernetesClient, snapshotsStore, new SystemClock()) + + val podsWatchEventSource = new ExecutorPodsWatchSnapshotSource( + snapshotsStore, + kubernetesClient) + + val eventsPollingExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + "kubernetes-executor-pod-polling-sync") + val podsPollingEventSource = new ExecutorPodsPollingSnapshotSource( + sc.conf, kubernetesClient, snapshotsStore, eventsPollingExecutor) + new KubernetesClusterSchedulerBackend( scheduler.asInstanceOf[TaskSchedulerImpl], sc.env.rpcEnv, - new KubernetesExecutorBuilder, kubernetesClient, - allocatorExecutor, - requestExecutorsService) + requestExecutorsService, + snapshotsStore, + executorPodsAllocator, + executorPodsLifecycleEventHandler, + podsWatchEventSource, + podsPollingEventSource) } override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index d86664c81071b..fa6dc2c479bbf 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -16,60 +16,32 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.io.Closeable -import java.net.InetAddress -import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, TimeUnit} -import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference} -import javax.annotation.concurrent.GuardedBy +import java.util.concurrent.ExecutorService -import io.fabric8.kubernetes.api.model._ -import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} -import io.fabric8.kubernetes.client.Watcher.Action -import scala.collection.JavaConverters._ -import scala.collection.mutable +import io.fabric8.kubernetes.client.KubernetesClient import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesConf -import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} -import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.scheduler.{ExecutorLossReason, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} private[spark] class KubernetesClusterSchedulerBackend( scheduler: TaskSchedulerImpl, rpcEnv: RpcEnv, - executorBuilder: KubernetesExecutorBuilder, kubernetesClient: KubernetesClient, - allocatorExecutor: ScheduledExecutorService, - requestExecutorsService: ExecutorService) + requestExecutorsService: ExecutorService, + snapshotsStore: ExecutorPodsSnapshotsStore, + podAllocator: ExecutorPodsAllocator, + lifecycleEventHandler: ExecutorPodsLifecycleManager, + watchEvents: ExecutorPodsWatchSnapshotSource, + pollEvents: ExecutorPodsPollingSnapshotSource) extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { - import KubernetesClusterSchedulerBackend._ - - private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) - private val RUNNING_EXECUTOR_PODS_LOCK = new Object - @GuardedBy("RUNNING_EXECUTOR_PODS_LOCK") - private val runningExecutorsToPods = new mutable.HashMap[String, Pod] - private val executorPodsByIPs = new ConcurrentHashMap[String, Pod]() - private val podsWithKnownExitReasons = new ConcurrentHashMap[String, ExecutorExited]() - private val disconnectedPodsByExecutorIdPendingRemoval = new ConcurrentHashMap[String, Pod]() - - private val kubernetesNamespace = conf.get(KUBERNETES_NAMESPACE) - - private val kubernetesDriverPodName = conf - .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(throw new SparkException("Must specify the driver pod name")) private implicit val requestExecutorContext = ExecutionContext.fromExecutorService( requestExecutorsService) - private val driverPod = kubernetesClient.pods() - .inNamespace(kubernetesNamespace) - .withName(kubernetesDriverPodName) - .get() - protected override val minRegisteredRatio = if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { 0.8 @@ -77,372 +49,93 @@ private[spark] class KubernetesClusterSchedulerBackend( super.minRegisteredRatio } - private val executorWatchResource = new AtomicReference[Closeable] - private val totalExpectedExecutors = new AtomicInteger(0) - - private val driverUrl = RpcEndpointAddress( - conf.get("spark.driver.host"), - conf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString - private val initialExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf) - private val podAllocationInterval = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) - - private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) - - private val executorLostReasonCheckMaxAttempts = conf.get( - KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) - - private val allocatorRunnable = new Runnable { - - // Maintains a map of executor id to count of checks performed to learn the loss reason - // for an executor. - private val executorReasonCheckAttemptCounts = new mutable.HashMap[String, Int] - - override def run(): Unit = { - handleDisconnectedExecutors() - - val executorsToAllocate = mutable.Map[String, Pod]() - val currentTotalRegisteredExecutors = totalRegisteredExecutors.get - val currentTotalExpectedExecutors = totalExpectedExecutors.get - val currentNodeToLocalTaskCount = getNodesWithLocalTaskCounts() - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - if (currentTotalRegisteredExecutors < runningExecutorsToPods.size) { - logDebug("Waiting for pending executors before scaling") - } else if (currentTotalExpectedExecutors <= runningExecutorsToPods.size) { - logDebug("Maximum allowed executor limit reached. Not scaling up further.") - } else { - for (_ <- 0 until math.min( - currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { - val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString - val executorConf = KubernetesConf.createExecutorConf( - conf, - executorId, - applicationId(), - driverPod) - val executorPod = executorBuilder.buildFromFeatures(executorConf) - val podWithAttachedContainer = new PodBuilder(executorPod.pod) - .editOrNewSpec() - .addToContainers(executorPod.container) - .endSpec() - .build() - - executorsToAllocate(executorId) = podWithAttachedContainer - logInfo( - s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") - } - } - } - - val allocatedExecutors = executorsToAllocate.mapValues { pod => - Utils.tryLog { - kubernetesClient.pods().create(pod) - } - } - - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - allocatedExecutors.map { - case (executorId, attemptedAllocatedExecutor) => - attemptedAllocatedExecutor.map { successfullyAllocatedExecutor => - runningExecutorsToPods.put(executorId, successfullyAllocatedExecutor) - } - } - } - } - - def handleDisconnectedExecutors(): Unit = { - // For each disconnected executor, synchronize with the loss reasons that may have been found - // by the executor pod watcher. If the loss reason was discovered by the watcher, - // inform the parent class with removeExecutor. - disconnectedPodsByExecutorIdPendingRemoval.asScala.foreach { - case (executorId, executorPod) => - val knownExitReason = Option(podsWithKnownExitReasons.remove( - executorPod.getMetadata.getName)) - knownExitReason.fold { - removeExecutorOrIncrementLossReasonCheckCount(executorId) - } { executorExited => - logWarning(s"Removing executor $executorId with loss reason " + executorExited.message) - removeExecutor(executorId, executorExited) - // We don't delete the pod running the executor that has an exit condition caused by - // the application from the Kubernetes API server. This allows users to debug later on - // through commands such as "kubectl logs " and - // "kubectl describe pod ". Note that exited containers have terminated and - // therefore won't take CPU and memory resources. - // Otherwise, the executor pod is marked to be deleted from the API server. - if (executorExited.exitCausedByApp) { - logInfo(s"Executor $executorId exited because of the application.") - deleteExecutorFromDataStructures(executorId) - } else { - logInfo(s"Executor $executorId failed because of a framework error.") - deleteExecutorFromClusterAndDataStructures(executorId) - } - } - } - } - - def removeExecutorOrIncrementLossReasonCheckCount(executorId: String): Unit = { - val reasonCheckCount = executorReasonCheckAttemptCounts.getOrElse(executorId, 0) - if (reasonCheckCount >= executorLostReasonCheckMaxAttempts) { - removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons.")) - deleteExecutorFromClusterAndDataStructures(executorId) - } else { - executorReasonCheckAttemptCounts.put(executorId, reasonCheckCount + 1) - } - } - - def deleteExecutorFromClusterAndDataStructures(executorId: String): Unit = { - deleteExecutorFromDataStructures(executorId).foreach { pod => - kubernetesClient.pods().delete(pod) - } - } - - def deleteExecutorFromDataStructures(executorId: String): Option[Pod] = { - disconnectedPodsByExecutorIdPendingRemoval.remove(executorId) - executorReasonCheckAttemptCounts -= executorId - podsWithKnownExitReasons.remove(executorId) - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - runningExecutorsToPods.remove(executorId).orElse { - logWarning(s"Unable to remove pod for unknown executor $executorId") - None - } - } - } - } - - override def sufficientResourcesRegistered(): Boolean = { - totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio + // Allow removeExecutor to be accessible by ExecutorPodsLifecycleEventHandler + private[k8s] def doRemoveExecutor(executorId: String, reason: ExecutorLossReason): Unit = { + removeExecutor(executorId, reason) } override def start(): Unit = { super.start() - executorWatchResource.set( - kubernetesClient - .pods() - .withLabel(SPARK_APP_ID_LABEL, applicationId()) - .watch(new ExecutorPodsWatcher())) - - allocatorExecutor.scheduleWithFixedDelay( - allocatorRunnable, 0L, podAllocationInterval, TimeUnit.MILLISECONDS) - if (!Utils.isDynamicAllocationEnabled(conf)) { - doRequestTotalExecutors(initialExecutors) + podAllocator.setTotalExpectedExecutors(initialExecutors) } + lifecycleEventHandler.start(this) + podAllocator.start(applicationId()) + watchEvents.start(applicationId()) + pollEvents.start(applicationId()) } override def stop(): Unit = { - // stop allocation of new resources and caches. - allocatorExecutor.shutdown() - allocatorExecutor.awaitTermination(30, TimeUnit.SECONDS) - - // send stop message to executors so they shut down cleanly super.stop() - try { - val resource = executorWatchResource.getAndSet(null) - if (resource != null) { - resource.close() - } - } catch { - case e: Throwable => logWarning("Failed to close the executor pod watcher", e) + Utils.tryLogNonFatalError { + snapshotsStore.stop() } - // then delete the executor pods Utils.tryLogNonFatalError { - deleteExecutorPodsOnStop() - executorPodsByIPs.clear() + watchEvents.stop() } + Utils.tryLogNonFatalError { - logInfo("Closing kubernetes client") - kubernetesClient.close() + pollEvents.stop() } - } - /** - * @return A map of K8s cluster nodes to the number of tasks that could benefit from data - * locality if an executor launches on the cluster node. - */ - private def getNodesWithLocalTaskCounts() : Map[String, Int] = { - val nodeToLocalTaskCount = synchronized { - mutable.Map[String, Int]() ++ hostToLocalTaskCount + Utils.tryLogNonFatalError { + kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .delete() } - for (pod <- executorPodsByIPs.values().asScala) { - // Remove cluster nodes that are running our executors already. - // TODO: This prefers spreading out executors across nodes. In case users want - // consolidating executors on fewer nodes, introduce a flag. See the spark.deploy.spreadOut - // flag that Spark standalone has: https://spark.apache.org/docs/latest/spark-standalone.html - nodeToLocalTaskCount.remove(pod.getSpec.getNodeName).nonEmpty || - nodeToLocalTaskCount.remove(pod.getStatus.getHostIP).nonEmpty || - nodeToLocalTaskCount.remove( - InetAddress.getByName(pod.getStatus.getHostIP).getCanonicalHostName).nonEmpty + Utils.tryLogNonFatalError { + ThreadUtils.shutdown(requestExecutorsService) + } + + Utils.tryLogNonFatalError { + kubernetesClient.close() } - nodeToLocalTaskCount.toMap[String, Int] } override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future[Boolean] { - totalExpectedExecutors.set(requestedTotal) + // TODO when we support dynamic allocation, the pod allocator should be told to process the + // current snapshot in order to decrease/increase the number of executors accordingly. + podAllocator.setTotalExpectedExecutors(requestedTotal) true } - override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { - val podsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { - executorIds.flatMap { executorId => - runningExecutorsToPods.remove(executorId) match { - case Some(pod) => - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - Some(pod) - - case None => - logWarning(s"Unable to remove pod for unknown executor $executorId") - None - } - } - } - - kubernetesClient.pods().delete(podsToDelete: _*) - true + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio } - private def deleteExecutorPodsOnStop(): Unit = { - val executorPodsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { - val runningExecutorPodsCopy = Seq(runningExecutorsToPods.values.toSeq: _*) - runningExecutorsToPods.clear() - runningExecutorPodsCopy - } - kubernetesClient.pods().delete(executorPodsToDelete: _*) + override def getExecutorIds(): Seq[String] = synchronized { + super.getExecutorIds() } - private class ExecutorPodsWatcher extends Watcher[Pod] { - - private val DEFAULT_CONTAINER_FAILURE_EXIT_STATUS = -1 - - override def eventReceived(action: Action, pod: Pod): Unit = { - val podName = pod.getMetadata.getName - val podIP = pod.getStatus.getPodIP - - action match { - case Action.MODIFIED if (pod.getStatus.getPhase == "Running" - && pod.getMetadata.getDeletionTimestamp == null) => - val clusterNodeName = pod.getSpec.getNodeName - logInfo(s"Executor pod $podName ready, launched at $clusterNodeName as IP $podIP.") - executorPodsByIPs.put(podIP, pod) - - case Action.DELETED | Action.ERROR => - val executorId = getExecutorId(pod) - logDebug(s"Executor pod $podName at IP $podIP was at $action.") - if (podIP != null) { - executorPodsByIPs.remove(podIP) - } - - val executorExitReason = if (action == Action.ERROR) { - logWarning(s"Received error event of executor pod $podName. Reason: " + - pod.getStatus.getReason) - executorExitReasonOnError(pod) - } else if (action == Action.DELETED) { - logWarning(s"Received delete event of executor pod $podName. Reason: " + - pod.getStatus.getReason) - executorExitReasonOnDelete(pod) - } else { - throw new IllegalStateException( - s"Unknown action that should only be DELETED or ERROR: $action") - } - podsWithKnownExitReasons.put(pod.getMetadata.getName, executorExitReason) - - if (!disconnectedPodsByExecutorIdPendingRemoval.containsKey(executorId)) { - log.warn(s"Executor with id $executorId was not marked as disconnected, but the " + - s"watch received an event of type $action for this executor. The executor may " + - "have failed to start in the first place and never registered with the driver.") - } - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - - case _ => logDebug(s"Received event of executor pod $podName: " + action) - } - } - - override def onClose(cause: KubernetesClientException): Unit = { - logDebug("Executor pod watch closed.", cause) - } - - private def getExecutorExitStatus(pod: Pod): Int = { - val containerStatuses = pod.getStatus.getContainerStatuses - if (!containerStatuses.isEmpty) { - // we assume the first container represents the pod status. This assumption may not hold - // true in the future. Revisit this if side-car containers start running inside executor - // pods. - getExecutorExitStatus(containerStatuses.get(0)) - } else DEFAULT_CONTAINER_FAILURE_EXIT_STATUS - } - - private def getExecutorExitStatus(containerStatus: ContainerStatus): Int = { - Option(containerStatus.getState).map { containerState => - Option(containerState.getTerminated).map { containerStateTerminated => - containerStateTerminated.getExitCode.intValue() - }.getOrElse(UNKNOWN_EXIT_CODE) - }.getOrElse(UNKNOWN_EXIT_CODE) - } - - private def isPodAlreadyReleased(pod: Pod): Boolean = { - val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - !runningExecutorsToPods.contains(executorId) - } - } - - private def executorExitReasonOnError(pod: Pod): ExecutorExited = { - val containerExitStatus = getExecutorExitStatus(pod) - // container was probably actively killed by the driver. - if (isPodAlreadyReleased(pod)) { - ExecutorExited(containerExitStatus, exitCausedByApp = false, - s"Container in pod ${pod.getMetadata.getName} exited from explicit termination " + - "request.") - } else { - val containerExitReason = s"Pod ${pod.getMetadata.getName}'s executor container " + - s"exited with exit status code $containerExitStatus." - ExecutorExited(containerExitStatus, exitCausedByApp = true, containerExitReason) - } - } - - private def executorExitReasonOnDelete(pod: Pod): ExecutorExited = { - val exitMessage = if (isPodAlreadyReleased(pod)) { - s"Container in pod ${pod.getMetadata.getName} exited from explicit termination request." - } else { - s"Pod ${pod.getMetadata.getName} deleted or lost." - } - ExecutorExited(getExecutorExitStatus(pod), exitCausedByApp = false, exitMessage) - } - - private def getExecutorId(pod: Pod): String = { - val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) - require(executorId != null, "Unexpected pod metadata; expected all executor pods " + - s"to have label $SPARK_EXECUTOR_ID_LABEL.") - executorId - } + override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { + kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .withLabelIn(SPARK_EXECUTOR_ID_LABEL, executorIds: _*) + .delete() + // Don't do anything else - let event handling from the Kubernetes API do the Spark changes } override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { new KubernetesDriverEndpoint(rpcEnv, properties) } - private class KubernetesDriverEndpoint( - rpcEnv: RpcEnv, - sparkProperties: Seq[(String, String)]) + private class KubernetesDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends DriverEndpoint(rpcEnv, sparkProperties) { override def onDisconnected(rpcAddress: RpcAddress): Unit = { - addressToExecutorId.get(rpcAddress).foreach { executorId => - if (disableExecutor(executorId)) { - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - runningExecutorsToPods.get(executorId).foreach { pod => - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - } - } - } - } + // Don't do anything besides disabling the executor - allow the Kubernetes API events to + // drive the rest of the lifecycle decisions + // TODO what if we disconnect from a networking issue? Probably want to mark the executor + // to be deleted eventually. + addressToExecutorId.get(rpcAddress).foreach(disableExecutor) } } -} -private object KubernetesClusterSchedulerBackend { - private val UNKNOWN_EXIT_CODE = -1 } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index 22568fe7ea3be..364b6fb367722 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -17,21 +17,42 @@ package org.apache.spark.scheduler.cluster.k8s import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features._ +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesExecutorBuilder( - provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = + provideBasicStep: (KubernetesConf [KubernetesExecutorSpecificConf]) + => BasicExecutorFeatureStep = new BasicExecutorFeatureStep(_), - provideSecretsStep: - (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = - new MountSecretsFeatureStep(_)) { + provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => MountSecretsFeatureStep = + new MountSecretsFeatureStep(_), + provideEnvSecretsStep: + (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = + new EnvSecretsFeatureStep(_), + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => LocalDirsFeatureStep = + new LocalDirsFeatureStep(_), + provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountVolumesFeatureStep) = + new MountVolumesFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { - val baseFeatures = Seq(provideBasicStep(kubernetesConf)) - val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) - } else baseFeatures + + val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) + val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Seq(provideSecretsStep(kubernetesConf)) + } else Nil + val secretEnvFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Seq(provideEnvSecretsStep(kubernetesConf)) + } else Nil + val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) { + Seq(provideVolumesStep(kubernetesConf)) + } else Nil + + val allFeatures = baseFeatures ++ secretFeature ++ secretEnvFeature ++ volumesFeature + var executorPod = SparkPod.initialPod() for (feature <- allFeatures) { executorPod = feature.configurePod(executorPod) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala new file mode 100644 index 0000000000000..527fc6b0d8f87 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{DoneablePod, HasMetadata, Pod, PodList} +import io.fabric8.kubernetes.client.{Watch, Watcher} +import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} + +object Fabric8Aliases { + type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + type LABELED_PODS = FilterWatchListDeletable[ + Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] + type SINGLE_POD = PodResource[Pod, DoneablePod] + type RESOURCE_LIST = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ + HasMetadata, Boolean] +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala index f10202f7a3546..661f942435921 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -22,7 +22,7 @@ import io.fabric8.kubernetes.api.model.{LocalObjectReferenceBuilder, PodBuilder} import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.JavaMainAppResource +import org.apache.spark.deploy.k8s.submit._ class KubernetesConfSuite extends SparkFunSuite { @@ -40,6 +40,9 @@ class KubernetesConfSuite extends SparkFunSuite { private val SECRET_NAMES_TO_MOUNT_PATHS = Map( "secret1" -> "/mnt/secrets/secret1", "secret2" -> "/mnt/secrets/secret2") + private val SECRET_ENV_VARS = Map( + "envName1" -> "name1:key1", + "envName2" -> "name2:key2") private val CUSTOM_ENVS = Map( "customEnvKey1" -> "customEnvValue1", "customEnvKey2" -> "customEnvValue2") @@ -53,9 +56,10 @@ class KubernetesConfSuite extends SparkFunSuite { APP_NAME, RESOURCE_NAME_PREFIX, APP_ID, - None, + mainAppResource = None, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(conf.appId === APP_ID) assert(conf.sparkConf.getAll.toMap === sparkConf.getAll.toMap) assert(conf.appResourceNamePrefix === RESOURCE_NAME_PREFIX) @@ -76,7 +80,8 @@ class KubernetesConfSuite extends SparkFunSuite { APP_ID, mainAppJar, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(kubernetesConfWithMainJar.sparkConf.get("spark.jars") .split(",") === Array("local:///opt/spark/jar1.jar", "local:///opt/spark/main.jar")) @@ -85,15 +90,59 @@ class KubernetesConfSuite extends SparkFunSuite { APP_NAME, RESOURCE_NAME_PREFIX, APP_ID, - None, + mainAppResource = None, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(kubernetesConfWithoutMainJar.sparkConf.get("spark.jars").split(",") === Array("local:///opt/spark/jar1.jar")) + assert(kubernetesConfWithoutMainJar.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.1) } - test("Resolve driver labels, annotations, secret mount paths, and envs.") { + test("Creating driver conf with a python primary file") { + val mainResourceFile = "local:///opt/spark/main.py" + val inputPyFiles = Array("local:///opt/spark/example2.py", "local:///example3.py") val sparkConf = new SparkConf(false) + .setJars(Seq("local:///opt/spark/jar1.jar")) + .set("spark.files", "local:///opt/spark/example4.py") + val mainAppResource = Some(PythonMainAppResource(mainResourceFile)) + val kubernetesConfWithMainResource = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppResource, + MAIN_CLASS, + APP_ARGS, + Some(inputPyFiles.mkString(","))) + assert(kubernetesConfWithMainResource.sparkConf.get("spark.jars").split(",") + === Array("local:///opt/spark/jar1.jar")) + assert(kubernetesConfWithMainResource.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.4) + assert(kubernetesConfWithMainResource.sparkFiles + === Array("local:///opt/spark/example4.py", mainResourceFile) ++ inputPyFiles) + } + + test("Testing explicit setting of memory overhead on non-JVM tasks") { + val sparkConf = new SparkConf(false) + .set(MEMORY_OVERHEAD_FACTOR, 0.3) + + val mainResourceFile = "local:///opt/spark/main.py" + val mainAppResource = Some(PythonMainAppResource(mainResourceFile)) + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppResource, + MAIN_CLASS, + APP_ARGS, + None) + assert(conf.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.3) + } + + test("Resolve driver labels, annotations, secret mount paths, envs, and memory overhead") { + val sparkConf = new SparkConf(false) + .set(MEMORY_OVERHEAD_FACTOR, 0.3) CUSTOM_LABELS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$key", value) } @@ -103,6 +152,9 @@ class KubernetesConfSuite extends SparkFunSuite { SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$key", value) } + SECRET_ENV_VARS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX$key", value) + } CUSTOM_ENVS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_ENV_PREFIX$key", value) } @@ -112,16 +164,19 @@ class KubernetesConfSuite extends SparkFunSuite { APP_NAME, RESOURCE_NAME_PREFIX, APP_ID, - None, + mainAppResource = None, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(conf.roleLabels === Map( SPARK_APP_ID_LABEL -> APP_ID, SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) ++ CUSTOM_LABELS) assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) assert(conf.roleEnvs === CUSTOM_ENVS) + assert(conf.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.3) } test("Basic executor translated fields.") { @@ -155,6 +210,9 @@ class KubernetesConfSuite extends SparkFunSuite { CUSTOM_ANNOTATIONS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_EXECUTOR_ANNOTATION_PREFIX$key", value) } + SECRET_ENV_VARS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX$key", value) + } SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRETS_PREFIX$key", value) } @@ -170,6 +228,6 @@ class KubernetesConfSuite extends SparkFunSuite { SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ CUSTOM_LABELS) assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) } - } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala new file mode 100644 index 0000000000000..d795d159773a8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class KubernetesVolumeUtilsSuite extends SparkFunSuite { + test("Parses hostPath volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true") + sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesHostPathVolumeConf] === + KubernetesHostPathVolumeConf("/hostPath")) + } + + test("Parses persistentVolumeClaim volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path") + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true") + sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimeName") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === + KubernetesPVCVolumeConf("claimeName")) + } + + test("Parses emptyDir volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + sparkConf.set("test.emptyDir.volumeName.options.medium", "medium") + sparkConf.set("test.emptyDir.volumeName.options.sizeLimit", "5G") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesEmptyDirVolumeConf] === + KubernetesEmptyDirVolumeConf(Some("medium"), Some("5G"))) + } + + test("Parses emptyDir volume options can be optional") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesEmptyDirVolumeConf] === + KubernetesEmptyDirVolumeConf(None, None)) + } + + test("Defaults optional readOnly to false") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.mountReadOnly === false) + } + + test("Gracefully fails on missing mount key") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mnt.path", "/path") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.isFailure === true) + assert(volumeSpec.failed.get.getMessage === "emptyDir.volumeName.mount.path") + } + + test("Gracefully fails on missing option key") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true") + sparkConf.set("test.hostPath.volumeName.options.pth", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.isFailure === true) + assert(volumeSpec.failed.get.getMessage === "hostPath.volumeName.options.path") + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index eee85b8baa730..165f46a07df2f 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -24,6 +24,8 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource +import org.apache.spark.deploy.k8s.submit.PythonMainAppResource class BasicDriverFeatureStepSuite extends SparkFunSuite { @@ -33,6 +35,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" private val APP_NAME = "spark-test" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" + private val PY_MAIN_CLASS = "org.apache.spark.deploy.PythonRunner" private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") private val CUSTOM_ANNOTATION_KEY = "customAnnotation" private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" @@ -47,6 +50,12 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { TEST_IMAGE_PULL_SECRETS.map { secret => new LocalObjectReferenceBuilder().withName(secret).build() } + private val emptyDriverSpecificConf = KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS) + test("Check the pod respects all configurations from the user.") { val sparkConf = new SparkConf() @@ -59,17 +68,16 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) val kubernetesConf = KubernetesConf( sparkConf, - KubernetesDriverSpecificConf( - None, - APP_NAME, - MAIN_CLASS, - APP_ARGS), + emptyDriverSpecificConf, RESOURCE_NAME_PREFIX, APP_ID, DRIVER_LABELS, DRIVER_ANNOTATIONS, Map.empty, - DRIVER_ENVS) + Map.empty, + DRIVER_ENVS, + Nil, + Seq.empty[String]) val featureStep = new BasicDriverFeatureStep(kubernetesConf) val basePod = SparkPod.initialPod() @@ -109,7 +117,6 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) assert(driverPodMetadata.getAnnotations.asScala === DRIVER_ANNOTATIONS) assert(configuredPod.pod.getSpec.getRestartPolicy === "Never") - val expectedSparkConf = Map( KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", "spark.app.id" -> APP_ID, @@ -118,6 +125,52 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { assert(featureStep.getAdditionalPodSystemProperties() === expectedSparkConf) } + test("Check appropriate entrypoint rerouting for various bindings") { + val javaSparkConf = new SparkConf() + .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "4g") + .set(CONTAINER_IMAGE, "spark-driver:latest") + val pythonSparkConf = new SparkConf() + .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "4g") + .set(CONTAINER_IMAGE, "spark-driver:latest") + val javaKubernetesConf = KubernetesConf( + javaSparkConf, + KubernetesDriverSpecificConf( + Some(JavaMainAppResource("")), + APP_NAME, + PY_MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + Map.empty, + DRIVER_ENVS, + Nil, + Seq.empty[String]) + val pythonKubernetesConf = KubernetesConf( + pythonSparkConf, + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("")), + APP_NAME, + PY_MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + Map.empty, + DRIVER_ENVS, + Nil, + Seq.empty[String]) + val javaFeatureStep = new BasicDriverFeatureStep(javaKubernetesConf) + val pythonFeatureStep = new BasicDriverFeatureStep(pythonKubernetesConf) + val basePod = SparkPod.initialPod() + val configuredJavaPod = javaFeatureStep.configurePod(basePod) + val configuredPythonPod = pythonFeatureStep.configurePod(basePod) + } + test("Additional system properties resolve jars and set cluster-mode confs.") { val allJars = Seq("local:///opt/spark/jar1.jar", "hdfs:///opt/spark/jar2.jar") val allFiles = Seq("https://localhost:9000/file1.txt", "local:///opt/spark/file2.txt") @@ -128,17 +181,17 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .set(CONTAINER_IMAGE, "spark-driver:latest") val kubernetesConf = KubernetesConf( sparkConf, - KubernetesDriverSpecificConf( - None, - APP_NAME, - MAIN_CLASS, - APP_ARGS), + emptyDriverSpecificConf, RESOURCE_NAME_PREFIX, APP_ID, DRIVER_LABELS, DRIVER_ANNOTATIONS, Map.empty, - Map.empty) + Map.empty, + DRIVER_ENVS, + Nil, + allFiles) + val step = new BasicDriverFeatureStep(kubernetesConf) val additionalProperties = step.getAdditionalPodSystemProperties() val expectedSparkConf = Map( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index a764f7630b5c8..a44fa1f2ffc63 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -87,7 +87,10 @@ class BasicExecutorFeatureStepSuite LABELS, ANNOTATIONS, Map.empty, - Map.empty)) + Map.empty, + Map.empty, + Nil, + Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) // The executor pod name and default labels. @@ -124,7 +127,10 @@ class BasicExecutorFeatureStepSuite LABELS, ANNOTATIONS, Map.empty, - Map.empty)) + Map.empty, + Map.empty, + Nil, + Seq.empty[String])) assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) } @@ -142,7 +148,10 @@ class BasicExecutorFeatureStepSuite LABELS, ANNOTATIONS, Map.empty, - Map("qux" -> "quux"))) + Map.empty, + Map("qux" -> "quux"), + Nil, + Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) checkEnv(executor, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 9f817d3bfc79a..7e916b3854404 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -59,7 +59,10 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) assert(kubernetesCredentialsStep.getAdditionalPodSystemProperties().isEmpty) @@ -88,7 +91,10 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) @@ -124,7 +130,10 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() val expectedSparkConf = Map( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index c299d56865ec0..8b91e93eecd8c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -65,7 +65,10 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Map.empty, + Nil, + Seq.empty[String])) assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) assert(configurationStep.getAdditionalKubernetesResources().size === 1) assert(configurationStep.getAdditionalKubernetesResources().head.isInstanceOf[Service]) @@ -94,7 +97,10 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Map.empty, + Nil, + Seq.empty[String])) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX val expectedHostName = s"$expectedServiceName.my-namespace.svc" @@ -113,7 +119,10 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Map.empty, + Nil, + Seq.empty[String])) val resolvedService = configurationStep .getAdditionalKubernetesResources() .head @@ -141,7 +150,10 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, - Map.empty), + Map.empty, + Map.empty, + Nil, + Seq.empty[String]), clock) val driverService = configurationStep .getAdditionalKubernetesResources() @@ -166,7 +178,10 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, - Map.empty), + Map.empty, + Map.empty, + Nil, + Seq.empty[String]), clock) fail("The driver bind address should not be allowed.") } catch { @@ -189,7 +204,10 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, - Map.empty), + Map.empty, + Map.empty, + Nil, + Seq.empty[String]), clock) fail("The driver host address should not be allowed.") } catch { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala new file mode 100644 index 0000000000000..1c8d84b76c56b --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.PodBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ + +class EnvSecretsFeatureStepSuite extends SparkFunSuite{ + private val KEY_REF_NAME_FOO = "foo" + private val KEY_REF_NAME_BAR = "bar" + private val KEY_REF_KEY_FOO = "key_foo" + private val KEY_REF_KEY_BAR = "key_bar" + private val ENV_NAME_FOO = "MY_FOO" + private val ENV_NAME_BAR = "MY_bar" + + test("sets up all keyRefs") { + val baseDriverPod = SparkPod.initialPod() + val envVarsToKeys = Map( + ENV_NAME_BAR -> s"${KEY_REF_NAME_BAR}:${KEY_REF_KEY_BAR}", + ENV_NAME_FOO -> s"${KEY_REF_NAME_FOO}:${KEY_REF_KEY_FOO}") + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + "resource-name-prefix", + "app-id", + Map.empty, + Map.empty, + Map.empty, + envVarsToKeys, + Map.empty, + Nil, + Seq.empty[String]) + + val step = new EnvSecretsFeatureStep(kubernetesConf) + val driverContainerWithEnvSecrets = step.configurePod(baseDriverPod).container + + val expectedVars = + Seq(s"${ENV_NAME_BAR}", s"${ENV_NAME_FOO}") + + expectedVars.foreach { envName => + assert(KubernetesFeaturesTestUtils.containerHasEnvVar(driverContainerWithEnvSecrets, envName)) + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala index 27bff74ce38af..f90380e30e52a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala @@ -16,7 +16,9 @@ */ package org.apache.spark.deploy.k8s.features -import io.fabric8.kubernetes.api.model.{HasMetadata, PodBuilder, SecretBuilder} +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Container, HasMetadata, PodBuilder, SecretBuilder} import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -58,4 +60,7 @@ object KubernetesFeaturesTestUtils { .build()) } + def containerHasEnvVar(container: Container, envVarName: String): Boolean = { + container.getEnv.asScala.exists(envVar => envVar.getName == envVarName) + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala new file mode 100644 index 0000000000000..a339827b819a9 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{EnvVarBuilder, VolumeBuilder, VolumeMountBuilder} +import org.mockito.Mockito +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} + +class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { + private val defaultLocalDir = "/var/data/default-local-dir" + private var sparkConf: SparkConf = _ + private var kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf] = _ + + before { + val realSparkConf = new SparkConf(false) + sparkConf = Mockito.spy(realSparkConf) + kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + "app-name", + "main", + Seq.empty), + "resource", + "app-id", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) + } + + test("Resolve to default local dir if neither env nor configuration are set") { + Mockito.doReturn(null).when(sparkConf).get("spark.local.dir") + Mockito.doReturn(null).when(sparkConf).getenv("SPARK_LOCAL_DIRS") + val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir) + val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod()) + assert(configuredPod.pod.getSpec.getVolumes.size === 1) + assert(configuredPod.pod.getSpec.getVolumes.get(0) === + new VolumeBuilder() + .withName(s"spark-local-dir-1") + .withNewEmptyDir() + .endEmptyDir() + .build()) + assert(configuredPod.container.getVolumeMounts.size === 1) + assert(configuredPod.container.getVolumeMounts.get(0) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-1") + .withMountPath(defaultLocalDir) + .build()) + assert(configuredPod.container.getEnv.size === 1) + assert(configuredPod.container.getEnv.get(0) === + new EnvVarBuilder() + .withName("SPARK_LOCAL_DIRS") + .withValue(defaultLocalDir) + .build()) + } + + test("Use configured local dirs split on comma if provided.") { + Mockito.doReturn("/var/data/my-local-dir-1,/var/data/my-local-dir-2") + .when(sparkConf).getenv("SPARK_LOCAL_DIRS") + val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir) + val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod()) + assert(configuredPod.pod.getSpec.getVolumes.size === 2) + assert(configuredPod.pod.getSpec.getVolumes.get(0) === + new VolumeBuilder() + .withName(s"spark-local-dir-1") + .withNewEmptyDir() + .endEmptyDir() + .build()) + assert(configuredPod.pod.getSpec.getVolumes.get(1) === + new VolumeBuilder() + .withName(s"spark-local-dir-2") + .withNewEmptyDir() + .endEmptyDir() + .build()) + assert(configuredPod.container.getVolumeMounts.size === 2) + assert(configuredPod.container.getVolumeMounts.get(0) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-1") + .withMountPath("/var/data/my-local-dir-1") + .build()) + assert(configuredPod.container.getVolumeMounts.get(1) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-2") + .withMountPath("/var/data/my-local-dir-2") + .build()) + assert(configuredPod.container.getEnv.size === 1) + assert(configuredPod.container.getEnv.get(0) === + new EnvVarBuilder() + .withName("SPARK_LOCAL_DIRS") + .withValue("/var/data/my-local-dir-1,/var/data/my-local-dir-2") + .build()) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 9d02f56cc206d..2b49b72dfa569 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -41,7 +41,10 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, secretNamesToMountPaths, - Map.empty) + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) val step = new MountSecretsFeatureStep(kubernetesConf) val driverPodWithSecretsMounted = step.configurePod(baseDriverPod).pod diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala new file mode 100644 index 0000000000000..d309aa94ec115 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ + +class MountVolumesFeatureStepSuite extends SparkFunSuite { + private val sparkConf = new SparkConf(false) + private val emptyKubernetesConf = KubernetesConf( + sparkConf = sparkConf, + roleSpecificConf = KubernetesDriverSpecificConf( + None, + "app-name", + "main", + Seq.empty), + appResourceNamePrefix = "resource", + appId = "app-id", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + roleVolumes = Nil, + sparkFiles = Nil) + + test("Mounts hostPath volumes") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/hostPath/tmp") + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + assert(configuredPod.pod.getSpec.getVolumes.get(0).getHostPath.getPath === "/hostPath/tmp") + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts pesistentVolumeClaims") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(pvcClaim.getClaimName === "pvcClaim") + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === true) + + } + + test("Mounts emptyDir") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + false, + KubernetesEmptyDirVolumeConf(Some("Memory"), Some("6G")) + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDir = configuredPod.pod.getSpec.getVolumes.get(0).getEmptyDir + assert(emptyDir.getMedium === "Memory") + assert(emptyDir.getSizeLimit.getAmount === "6G") + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts emptyDir with no options") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + false, + KubernetesEmptyDirVolumeConf(None, None) + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDir = configuredPod.pod.getSpec.getVolumes.get(0).getEmptyDir + assert(emptyDir.getMedium === "") + assert(emptyDir.getSizeLimit.getAmount === null) + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts multiple volumes") { + val hpVolumeConf = KubernetesVolumeSpec( + "hpVolume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/hostPath/tmp") + ) + val pvcVolumeConf = KubernetesVolumeSpec( + "checkpointVolume", + "/checkpoints", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val volumesConf = hpVolumeConf :: pvcVolumeConf :: Nil + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumesConf) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 2) + assert(configuredPod.container.getVolumeMounts.size() === 2) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..18874afe6e53a --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.PythonMainAppResource + +class JavaDriverFeatureStepSuite extends SparkFunSuite { + + test("Java Step modifies container correctly") { + val baseDriverPod = SparkPod.initialPod() + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("local:///main.jar")), + "test-class", + "java-runner", + Seq("5 7")), + appResourceNamePrefix = "", + appId = "", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + roleVolumes = Nil, + sparkFiles = Seq.empty[String]) + + val step = new JavaDriverFeatureStep(kubernetesConf) + val driverPod = step.configurePod(baseDriverPod).pod + val driverContainerwithJavaStep = step.configurePod(baseDriverPod).container + assert(driverContainerwithJavaStep.getArgs.size === 7) + val args = driverContainerwithJavaStep + .getArgs.asScala + assert(args === List( + "driver", + "--properties-file", SPARK_CONF_PATH, + "--class", "test-class", + "spark-internal", "5 7")) + + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..a5dac6869327d --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.PythonMainAppResource + +class PythonDriverFeatureStepSuite extends SparkFunSuite { + + test("Python Step modifies container correctly") { + val expectedMainResource = "/main.py" + val mainResource = "local:///main.py" + val pyFiles = Seq("local:///example2.py", "local:///example3.py") + val expectedPySparkFiles = + "/example2.py:/example3.py" + val baseDriverPod = SparkPod.initialPod() + val sparkConf = new SparkConf(false) + .set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, mainResource) + .set(KUBERNETES_PYSPARK_PY_FILES, pyFiles.mkString(",")) + .set("spark.files", "local:///example.py") + .set(PYSPARK_MAJOR_PYTHON_VERSION, "2") + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("local:///main.py")), + "test-app", + "python-runner", + Seq("5 7")), + appResourceNamePrefix = "", + appId = "", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + roleVolumes = Nil, + sparkFiles = Seq.empty[String]) + + val step = new PythonDriverFeatureStep(kubernetesConf) + val driverPod = step.configurePod(baseDriverPod).pod + val driverContainerwithPySpark = step.configurePod(baseDriverPod).container + assert(driverContainerwithPySpark.getEnv.size === 4) + val envs = driverContainerwithPySpark + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(ENV_PYSPARK_PRIMARY) === expectedMainResource) + assert(envs(ENV_PYSPARK_FILES) === expectedPySparkFiles) + assert(envs(ENV_PYSPARK_ARGS) === "5 7") + assert(envs(ENV_PYSPARK_MAJOR_PYTHON_VERSION) === "2") + } + test("Python Step testing empty pyfiles") { + val mainResource = "local:///main.py" + val baseDriverPod = SparkPod.initialPod() + val sparkConf = new SparkConf(false) + .set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, mainResource) + .set(PYSPARK_MAJOR_PYTHON_VERSION, "3") + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("local:///main.py")), + "test-class-py", + "python-runner", + Seq.empty[String]), + appResourceNamePrefix = "", + appId = "", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + roleVolumes = Nil, + sparkFiles = Seq.empty[String]) + val step = new PythonDriverFeatureStep(kubernetesConf) + val driverContainerwithPySpark = step.configurePod(baseDriverPod).container + val args = driverContainerwithPySpark + .getArgs.asScala + assert(driverContainerwithPySpark.getArgs.size === 5) + assert(args === List( + "driver-py", + "--properties-file", SPARK_CONF_PATH, + "--class", "test-class-py")) + val envs = driverContainerwithPySpark + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(ENV_PYSPARK_MAJOR_PYTHON_VERSION) === "3") + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index c1b203e03a357..4d8e79189ff32 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.mockito.MockitoSugar._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ class ClientSuite extends SparkFunSuite with BeforeAndAfter { @@ -103,15 +104,11 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { .build() } - private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ - HasMetadata, Boolean] - private type Pods = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - @Mock private var kubernetesClient: KubernetesClient = _ @Mock - private var podOperations: Pods = _ + private var podOperations: PODS = _ @Mock private var namedPods: PodResource[Pod, DoneablePod] = _ @@ -123,7 +120,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private var driverBuilder: KubernetesDriverBuilder = _ @Mock - private var resourceList: ResourceList = _ + private var resourceList: RESOURCE_LIST = _ private var kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf] = _ @@ -142,7 +139,10 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) when(podOperations.withName(POD_NAME)).thenReturn(namedPods) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 161f9afe7bba9..046e578b94629 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -17,15 +17,22 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.features._ +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} class KubernetesDriverBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val CREDENTIALS_STEP_TYPE = "credentials" private val SERVICE_STEP_TYPE = "service" + private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val SECRETS_STEP_TYPE = "mount-secrets" + private val JAVA_STEP_TYPE = "java-bindings" + private val PYSPARK_STEP_TYPE = "pyspark-bindings" + private val ENV_SECRETS_STEP_TYPE = "env-secrets" + private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) @@ -36,21 +43,41 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep]) + private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) + private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val javaStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + JAVA_STEP_TYPE, classOf[JavaDriverFeatureStep]) + + private val pythonStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + PYSPARK_STEP_TYPE, classOf[PythonDriverFeatureStep]) + + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) + + private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) + private val builderUnderTest: KubernetesDriverBuilder = new KubernetesDriverBuilder( _ => basicFeatureStep, _ => credentialsStep, _ => serviceStep, - _ => secretsStep) + _ => secretsStep, + _ => envSecretsStep, + _ => localDirsStep, + _ => mountVolumesStep, + _ => javaStep, + _ => pythonStep) test("Apply fundamental steps all the time.") { val conf = KubernetesConf( new SparkConf(false), KubernetesDriverSpecificConf( - None, + Some(JavaMainAppResource("example.jar")), "test-app", "main", Seq.empty), @@ -59,12 +86,17 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE) + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + JAVA_STEP_TYPE) } test("Apply secrets step if secrets are present.") { @@ -80,15 +112,106 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map("secret" -> "secretMountPath"), - Map.empty) + Map("EnvName" -> "SecretName:secretKey"), + Map.empty, + Nil, + Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, - SECRETS_STEP_TYPE) + LOCAL_DIRS_STEP_TYPE, + SECRETS_STEP_TYPE, + ENV_SECRETS_STEP_TYPE, + JAVA_STEP_TYPE) } + test("Apply Java step if main resource is none.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + JAVA_STEP_TYPE) + } + + test("Apply Python step if main resource is python.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("example.py")), + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + PYSPARK_STEP_TYPE) + } + + test("Apply volumes step if mounts are present.") { + val volumeSpec = KubernetesVolumeSpec( + "volume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/path")) + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + volumeSpec :: Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + MOUNT_VOLUMES_STEP_TYPE, + JAVA_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) : Unit = { assert(resolvedSpec.systemProperties.size === stepTypes.size) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala new file mode 100644 index 0000000000000..f7721e6fd6388 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod +import scala.collection.mutable + +class DeterministicExecutorPodsSnapshotsStore extends ExecutorPodsSnapshotsStore { + + private val snapshotsBuffer = mutable.Buffer.empty[ExecutorPodsSnapshot] + private val subscribers = mutable.Buffer.empty[Seq[ExecutorPodsSnapshot] => Unit] + + private var currentSnapshot = ExecutorPodsSnapshot() + + override def addSubscriber + (processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = { + subscribers += onNewSnapshots + } + + override def stop(): Unit = {} + + def notifySubscribers(): Unit = { + subscribers.foreach(_(snapshotsBuffer)) + snapshotsBuffer.clear() + } + + override def updatePod(updatedPod: Pod): Unit = { + currentSnapshot = currentSnapshot.withUpdate(updatedPod) + snapshotsBuffer += currentSnapshot + } + + override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = { + currentSnapshot = ExecutorPodsSnapshot(newSnapshot) + snapshotsBuffer += currentSnapshot + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala new file mode 100644 index 0000000000000..c6b667ed85e8c --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, Pod, PodBuilder} + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.SparkPod + +object ExecutorLifecycleTestUtils { + + val TEST_SPARK_APP_ID = "spark-app-id" + + def failedExecutorWithoutDeletion(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("failed") + .addNewContainerStatus() + .withName("spark-executor") + .withImage("k8s-spark") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(1) + .endTerminated() + .endState() + .endContainerStatus() + .addNewContainerStatus() + .withName("spark-executor-sidecar") + .withImage("k8s-spark-sidecar") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(1) + .endTerminated() + .endState() + .endContainerStatus() + .withMessage("Executor failed.") + .withReason("Executor failed because of a thrown error.") + .endStatus() + .build() + } + + def pendingExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("pending") + .endStatus() + .build() + } + + def runningExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("running") + .endStatus() + .build() + } + + def succeededExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("succeeded") + .endStatus() + .build() + } + + def deletedExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewMetadata() + .withNewDeletionTimestamp("523012521") + .endMetadata() + .build() + } + + def unknownExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("unknown") + .endStatus() + .build() + } + + def podWithAttachedContainerForId(executorId: Long): Pod = { + val sparkPod = executorPodWithId(executorId) + val podWithAttachedContainer = new PodBuilder(sparkPod.pod) + .editOrNewSpec() + .addToContainers(sparkPod.container) + .endSpec() + .build() + podWithAttachedContainer + } + + def executorPodWithId(executorId: Long): SparkPod = { + val pod = new PodBuilder() + .withNewMetadata() + .withName(s"spark-executor-$executorId") + .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) + .endMetadata() + .build() + val container = new ContainerBuilder() + .withName("spark-executor") + .withImage("k8s-spark") + .build() + SparkPod(pod, container) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala new file mode 100644 index 0000000000000..0c19f5946b75f --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.PodResource +import org.mockito.{ArgumentMatcher, Matchers, Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.{never, times, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ +import org.apache.spark.util.ManualClock + +class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { + + private val driverPodName = "driver" + + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(driverPodName) + .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) + .withUid("driver-pod-uid") + .endMetadata() + .build() + + private val conf = new SparkConf().set(KUBERNETES_DRIVER_POD_NAME, driverPodName) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000L) + + private var waitForExecutorPodsClock: ManualClock = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var labeledPods: LABELED_PODS = _ + + @Mock + private var driverPodOperations: PodResource[Pod, DoneablePod] = _ + + @Mock + private var executorBuilder: KubernetesExecutorBuilder = _ + + private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _ + + private var podsAllocatorUnderTest: ExecutorPodsAllocator = _ + + before { + MockitoAnnotations.initMocks(this) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withName(driverPodName)).thenReturn(driverPodOperations) + when(driverPodOperations.get).thenReturn(driverPod) + when(executorBuilder.buildFromFeatures(kubernetesConfWithCorrectFields())) + .thenAnswer(executorPodAnswer()) + snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() + waitForExecutorPodsClock = new ManualClock(0L) + podsAllocatorUnderTest = new ExecutorPodsAllocator( + conf, executorBuilder, kubernetesClient, snapshotsStore, waitForExecutorPodsClock) + podsAllocatorUnderTest.start(TEST_SPARK_APP_ID) + } + + test("Initially request executors in batches. Do not request another batch if the" + + " first has not finished.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (nextId <- 1 to podAllocationSize) { + verify(podOperations).create(podWithAttachedContainerForId(nextId)) + } + verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1)) + } + + test("Request executors in batches. Allow another batch to be requested if" + + " all pending executors start running.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (execId <- 1 until podAllocationSize) { + snapshotsStore.updatePod(runningExecutor(execId)) + } + snapshotsStore.notifySubscribers() + verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1)) + snapshotsStore.updatePod(runningExecutor(podAllocationSize)) + snapshotsStore.notifySubscribers() + verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1)) + snapshotsStore.updatePod(runningExecutor(podAllocationSize)) + snapshotsStore.notifySubscribers() + verify(podOperations, times(podAllocationSize + 1)).create(any(classOf[Pod])) + } + + test("When a current batch reaches error states immediately, re-request" + + " them on the next batch.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (execId <- 1 until podAllocationSize) { + snapshotsStore.updatePod(runningExecutor(execId)) + } + val failedPod = failedExecutorWithoutDeletion(podAllocationSize) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1)) + } + + test("When an executor is requested but the API does not report it in a reasonable time, retry" + + " requesting that executor.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + waitForExecutorPodsClock.setTime(podCreationTimeout + 1) + when(podOperations.withLabel(SPARK_EXECUTOR_ID_LABEL, "1")).thenReturn(labeledPods) + snapshotsStore.notifySubscribers() + verify(labeledPods).delete() + verify(podOperations).create(podWithAttachedContainerForId(2)) + } + + private def executorPodAnswer(): Answer[SparkPod] = { + new Answer[SparkPod] { + override def answer(invocation: InvocationOnMock): SparkPod = { + val k8sConf = invocation.getArgumentAt( + 0, classOf[KubernetesConf[KubernetesExecutorSpecificConf]]) + executorPodWithId(k8sConf.roleSpecificConf.executorId.toInt) + } + } + } + + private def kubernetesConfWithCorrectFields(): KubernetesConf[KubernetesExecutorSpecificConf] = + Matchers.argThat(new ArgumentMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { + override def matches(argument: scala.Any): Boolean = { + if (!argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]]) { + false + } else { + val k8sConf = argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] + val executorSpecificConf = k8sConf.roleSpecificConf + val expectedK8sConf = KubernetesConf.createExecutorConf( + conf, + executorSpecificConf.executorId, + TEST_SPARK_APP_ID, + driverPod) + k8sConf.sparkConf.getAll.toMap == conf.getAll.toMap && + // Since KubernetesConf.createExecutorConf clones the SparkConf object, force + // deep equality comparison for the SparkConf object and use object equality + // comparison on all other fields. + k8sConf.copy(sparkConf = conf) == expectedK8sConf.copy(sparkConf = conf) + } + } + }) + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala new file mode 100644 index 0000000000000..562ace9f49d4d --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import com.google.common.cache.CacheBuilder +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.PodResource +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.{mock, times, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.ExecutorExited +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfter { + + private var namedExecutorPods: mutable.Map[String, PodResource[Pod, DoneablePod]] = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var executorBuilder: KubernetesExecutorBuilder = _ + + @Mock + private var schedulerBackend: KubernetesClusterSchedulerBackend = _ + + private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _ + private var eventHandlerUnderTest: ExecutorPodsLifecycleManager = _ + + before { + MockitoAnnotations.initMocks(this) + val removedExecutorsCache = CacheBuilder.newBuilder().build[java.lang.Long, java.lang.Long] + snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() + namedExecutorPods = mutable.Map.empty[String, PodResource[Pod, DoneablePod]] + when(schedulerBackend.getExecutorIds()).thenReturn(Seq.empty[String]) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withName(any(classOf[String]))).thenAnswer(namedPodsAnswer()) + eventHandlerUnderTest = new ExecutorPodsLifecycleManager( + new SparkConf(), + executorBuilder, + kubernetesClient, + snapshotsStore, + removedExecutorsCache) + eventHandlerUnderTest.start(schedulerBackend) + } + + test("When an executor reaches error states immediately, remove from the scheduler backend.") { + val failedPod = failedExecutorWithoutDeletion(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + verify(namedExecutorPods(failedPod.getMetadata.getName)).delete() + } + + test("Don't remove executors twice from Spark but remove from K8s repeatedly.") { + val failedPod = failedExecutorWithoutDeletion(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend, times(1)).doRemoveExecutor("1", expectedLossReason) + verify(namedExecutorPods(failedPod.getMetadata.getName), times(2)).delete() + } + + test("When the scheduler backend lists executor ids that aren't present in the cluster," + + " remove those executors from Spark.") { + when(schedulerBackend.getExecutorIds()).thenReturn(Seq("1")) + val msg = s"The executor with ID 1 was not found in the cluster but we didn't" + + s" get a reason why. Marking the executor as failed. The executor may have been" + + s" deleted but the driver missed the deletion event." + val expectedLossReason = ExecutorExited(-1, exitCausedByApp = false, msg) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + } + + private def exitReasonMessage(failedExecutorId: Int, failedPod: Pod): String = { + s""" + |The executor with id $failedExecutorId exited with exit code 1. + |The API gave the following brief reason: ${failedPod.getStatus.getReason} + |The API gave the following message: ${failedPod.getStatus.getMessage} + |The API gave the following container statuses: + | + |${failedPod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + """.stripMargin + } + + private def namedPodsAnswer(): Answer[PodResource[Pod, DoneablePod]] = { + new Answer[PodResource[Pod, DoneablePod]] { + override def answer(invocation: InvocationOnMock): PodResource[Pod, DoneablePod] = { + val podName = invocation.getArgumentAt(0, classOf[String]) + namedExecutorPods.getOrElseUpdate( + podName, mock(classOf[PodResource[Pod, DoneablePod]])) + } + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala new file mode 100644 index 0000000000000..1b26d6af296a5 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.TimeUnit + +import io.fabric8.kubernetes.api.model.PodListBuilder +import io.fabric8.kubernetes.client.KubernetesClient +import org.jmock.lib.concurrent.DeterministicScheduler +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.{verify, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsPollingSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter { + + private val sparkConf = new SparkConf + + private val pollingInterval = sparkConf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL) + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var appIdLabeledPods: LABELED_PODS = _ + + @Mock + private var executorRoleLabeledPods: LABELED_PODS = _ + + @Mock + private var eventQueue: ExecutorPodsSnapshotsStore = _ + + private var pollingExecutor: DeterministicScheduler = _ + private var pollingSourceUnderTest: ExecutorPodsPollingSnapshotSource = _ + + before { + MockitoAnnotations.initMocks(this) + pollingExecutor = new DeterministicScheduler() + pollingSourceUnderTest = new ExecutorPodsPollingSnapshotSource( + sparkConf, + kubernetesClient, + eventQueue, + pollingExecutor) + pollingSourceUnderTest.start(TEST_SPARK_APP_ID) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)) + .thenReturn(appIdLabeledPods) + when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)) + .thenReturn(executorRoleLabeledPods) + } + + test("Items returned by the API should be pushed to the event queue") { + when(executorRoleLabeledPods.list()) + .thenReturn(new PodListBuilder() + .addToItems( + runningExecutor(1), + runningExecutor(2)) + .build()) + pollingExecutor.tick(pollingInterval, TimeUnit.MILLISECONDS) + verify(eventQueue).replaceSnapshot(Seq(runningExecutor(1), runningExecutor(2))) + + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala new file mode 100644 index 0000000000000..70e19c904eddb --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsSnapshotSuite extends SparkFunSuite { + + test("States are interpreted correctly from pod metadata.") { + val pods = Seq( + pendingExecutor(0), + runningExecutor(1), + succeededExecutor(2), + failedExecutorWithoutDeletion(3), + deletedExecutor(4), + unknownExecutor(5)) + val snapshot = ExecutorPodsSnapshot(pods) + assert(snapshot.executorPods === + Map( + 0L -> PodPending(pods(0)), + 1L -> PodRunning(pods(1)), + 2L -> PodSucceeded(pods(2)), + 3L -> PodFailed(pods(3)), + 4L -> PodDeleted(pods(4)), + 5L -> PodUnknown(pods(5)))) + } + + test("Updates add new pods for non-matching ids and edit existing pods for matching ids") { + val originalPods = Seq( + pendingExecutor(0), + runningExecutor(1)) + val originalSnapshot = ExecutorPodsSnapshot(originalPods) + val snapshotWithUpdatedPod = originalSnapshot.withUpdate(succeededExecutor(1)) + assert(snapshotWithUpdatedPod.executorPods === + Map( + 0L -> PodPending(originalPods(0)), + 1L -> PodSucceeded(succeededExecutor(1)))) + val snapshotWithNewPod = snapshotWithUpdatedPod.withUpdate(pendingExecutor(2)) + assert(snapshotWithNewPod.executorPods === + Map( + 0L -> PodPending(originalPods(0)), + 1L -> PodSucceeded(succeededExecutor(1)), + 2L -> PodPending(pendingExecutor(2)))) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala new file mode 100644 index 0000000000000..cf54b3c4eb329 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference + +import io.fabric8.kubernetes.api.model.{Pod, PodBuilder} +import org.jmock.lib.concurrent.DeterministicScheduler +import org.scalatest.BeforeAndAfter +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Constants._ + +class ExecutorPodsSnapshotsStoreSuite extends SparkFunSuite with BeforeAndAfter { + + private var eventBufferScheduler: DeterministicScheduler = _ + private var eventQueueUnderTest: ExecutorPodsSnapshotsStoreImpl = _ + + before { + eventBufferScheduler = new DeterministicScheduler() + eventQueueUnderTest = new ExecutorPodsSnapshotsStoreImpl(eventBufferScheduler) + } + + test("Subscribers get notified of events periodically.") { + val receivedSnapshots1 = mutable.Buffer.empty[ExecutorPodsSnapshot] + val receivedSnapshots2 = mutable.Buffer.empty[ExecutorPodsSnapshot] + eventQueueUnderTest.addSubscriber(1000) { + receivedSnapshots1 ++= _ + } + eventQueueUnderTest.addSubscriber(2000) { + receivedSnapshots2 ++= _ + } + + eventBufferScheduler.runUntilIdle() + assert(receivedSnapshots1 === Seq(ExecutorPodsSnapshot())) + assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot())) + + pushPodWithIndex(1) + // Force time to move forward so that the buffer is emitted, scheduling the + // processing task on the subscription executor... + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + // ... then actually execute the subscribers. + + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot())) + + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + + // Don't repeat snapshots + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + assert(receivedSnapshots2 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + pushPodWithIndex(2) + pushPodWithIndex(3) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3))))) + assert(receivedSnapshots2 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3))))) + assert(receivedSnapshots1 === receivedSnapshots2) + } + + test("Even without sending events, initially receive an empty buffer.") { + val receivedInitialSnapshot = new AtomicReference[Seq[ExecutorPodsSnapshot]](null) + eventQueueUnderTest.addSubscriber(1000) { + receivedInitialSnapshot.set + } + assert(receivedInitialSnapshot.get == null) + eventBufferScheduler.runUntilIdle() + assert(receivedInitialSnapshot.get === Seq(ExecutorPodsSnapshot())) + } + + test("Replacing the snapshot passes the new snapshot to subscribers.") { + val receivedSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot] + eventQueueUnderTest.addSubscriber(1000) { + receivedSnapshots ++= _ + } + eventQueueUnderTest.updatePod(podWithIndex(1)) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + eventQueueUnderTest.replaceSnapshot(Seq(podWithIndex(2))) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(2))))) + } + + private def pushPodWithIndex(index: Int): Unit = + eventQueueUnderTest.updatePod(podWithIndex(index)) + + private def podWithIndex(index: Int): Pod = + new PodBuilder() + .editOrNewMetadata() + .withName(s"pod-$index") + .addToLabels(SPARK_EXECUTOR_ID_LABEL, index.toString) + .endMetadata() + .editOrNewStatus() + .withPhase("running") + .endStatus() + .build() +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala new file mode 100644 index 0000000000000..ac1968b4ff810 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Mockito.{verify, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsWatchSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter { + + @Mock + private var eventQueue: ExecutorPodsSnapshotsStore = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var appIdLabeledPods: LABELED_PODS = _ + + @Mock + private var executorRoleLabeledPods: LABELED_PODS = _ + + @Mock + private var watchConnection: Watch = _ + + private var watch: ArgumentCaptor[Watcher[Pod]] = _ + + private var watchSourceUnderTest: ExecutorPodsWatchSnapshotSource = _ + + before { + MockitoAnnotations.initMocks(this) + watch = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)) + .thenReturn(appIdLabeledPods) + when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)) + .thenReturn(executorRoleLabeledPods) + when(executorRoleLabeledPods.watch(watch.capture())).thenReturn(watchConnection) + watchSourceUnderTest = new ExecutorPodsWatchSnapshotSource( + eventQueue, kubernetesClient) + watchSourceUnderTest.start(TEST_SPARK_APP_ID) + } + + test("Watch events should be pushed to the snapshots store as snapshot updates.") { + watch.getValue.eventReceived(Action.ADDED, runningExecutor(1)) + watch.getValue.eventReceived(Action.MODIFIED, runningExecutor(2)) + verify(eventQueue).updatePod(runningExecutor(1)) + verify(eventQueue).updatePod(runningExecutor(2)) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 96065e83f069c..52e7a12dbaf06 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -16,85 +16,36 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, DoneablePod, Pod, PodBuilder, PodList} -import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} -import io.fabric8.kubernetes.client.Watcher.Action -import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} -import org.hamcrest.{BaseMatcher, Description, Matcher} -import org.mockito.{AdditionalAnswers, ArgumentCaptor, Matchers, Mock, MockitoAnnotations} -import org.mockito.Matchers.{any, eq => mockitoEq} -import org.mockito.Mockito.{doNothing, never, times, verify, when} +import io.fabric8.kubernetes.client.KubernetesClient +import org.jmock.lib.concurrent.DeterministicScheduler +import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Matchers.{eq => mockitoEq} +import org.mockito.Mockito.{never, verify, when} import org.scalatest.BeforeAndAfter -import org.scalatest.mockito.MockitoSugar._ -import scala.collection.JavaConverters._ -import scala.concurrent.Future import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.rpc._ -import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.{ExecutorKilled, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.ThreadUtils +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils.TEST_SPARK_APP_ID class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAndAfter { - private val APP_ID = "test-spark-app" - private val DRIVER_POD_NAME = "spark-driver-pod" - private val NAMESPACE = "test-namespace" - private val SPARK_DRIVER_HOST = "localhost" - private val SPARK_DRIVER_PORT = 7077 - private val POD_ALLOCATION_INTERVAL = "1m" - private val FIRST_EXECUTOR_POD = new PodBuilder() - .withNewMetadata() - .withName("pod1") - .endMetadata() - .withNewSpec() - .withNodeName("node1") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.100") - .endStatus() - .build() - private val SECOND_EXECUTOR_POD = new PodBuilder() - .withNewMetadata() - .withName("pod2") - .endMetadata() - .withNewSpec() - .withNodeName("node2") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.101") - .endStatus() - .build() - - private type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - private type LABELED_PODS = FilterWatchListDeletable[ - Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] - private type IN_NAMESPACE_PODS = NonNamespaceOperation[ - Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - - @Mock - private var sparkContext: SparkContext = _ - - @Mock - private var listenerBus: LiveListenerBus = _ - - @Mock - private var taskSchedulerImpl: TaskSchedulerImpl = _ + private val requestExecutorsService = new DeterministicScheduler() + private val sparkConf = new SparkConf(false) + .set("spark.executor.instances", "3") @Mock - private var allocatorExecutor: ScheduledExecutorService = _ + private var sc: SparkContext = _ @Mock - private var requestExecutorsService: ExecutorService = _ + private var rpcEnv: RpcEnv = _ @Mock - private var executorBuilder: KubernetesExecutorBuilder = _ + private var driverEndpointRef: RpcEndpointRef = _ @Mock private var kubernetesClient: KubernetesClient = _ @@ -103,347 +54,97 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private var podOperations: PODS = _ @Mock - private var podsWithLabelOperations: LABELED_PODS = _ + private var labeledPods: LABELED_PODS = _ @Mock - private var podsInNamespace: IN_NAMESPACE_PODS = _ + private var taskScheduler: TaskSchedulerImpl = _ @Mock - private var podsWithDriverName: PodResource[Pod, DoneablePod] = _ + private var eventQueue: ExecutorPodsSnapshotsStore = _ @Mock - private var rpcEnv: RpcEnv = _ + private var podAllocator: ExecutorPodsAllocator = _ @Mock - private var driverEndpointRef: RpcEndpointRef = _ + private var lifecycleEventHandler: ExecutorPodsLifecycleManager = _ @Mock - private var executorPodsWatch: Watch = _ + private var watchEvents: ExecutorPodsWatchSnapshotSource = _ @Mock - private var successFuture: Future[Boolean] = _ + private var pollEvents: ExecutorPodsPollingSnapshotSource = _ - private var sparkConf: SparkConf = _ - private var executorPodsWatcherArgument: ArgumentCaptor[Watcher[Pod]] = _ - private var allocatorRunnable: ArgumentCaptor[Runnable] = _ - private var requestExecutorRunnable: ArgumentCaptor[Runnable] = _ private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _ - - private val driverPod = new PodBuilder() - .withNewMetadata() - .withName(DRIVER_POD_NAME) - .addToLabels(SPARK_APP_ID_LABEL, APP_ID) - .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) - .endMetadata() - .build() + private var schedulerBackendUnderTest: KubernetesClusterSchedulerBackend = _ before { MockitoAnnotations.initMocks(this) - sparkConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) - .set(KUBERNETES_NAMESPACE, NAMESPACE) - .set("spark.driver.host", SPARK_DRIVER_HOST) - .set("spark.driver.port", SPARK_DRIVER_PORT.toString) - .set(KUBERNETES_ALLOCATION_BATCH_DELAY.key, POD_ALLOCATION_INTERVAL) - executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) - allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) - requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + when(taskScheduler.sc).thenReturn(sc) + when(sc.conf).thenReturn(sparkConf) driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) - when(sparkContext.conf).thenReturn(sparkConf) - when(sparkContext.listenerBus).thenReturn(listenerBus) - when(taskSchedulerImpl.sc).thenReturn(sparkContext) - when(kubernetesClient.pods()).thenReturn(podOperations) - when(podOperations.withLabel(SPARK_APP_ID_LABEL, APP_ID)).thenReturn(podsWithLabelOperations) - when(podsWithLabelOperations.watch(executorPodsWatcherArgument.capture())) - .thenReturn(executorPodsWatch) - when(podOperations.inNamespace(NAMESPACE)).thenReturn(podsInNamespace) - when(podsInNamespace.withName(DRIVER_POD_NAME)).thenReturn(podsWithDriverName) - when(podsWithDriverName.get()).thenReturn(driverPod) - when(allocatorExecutor.scheduleWithFixedDelay( - allocatorRunnable.capture(), - mockitoEq(0L), - mockitoEq(TimeUnit.MINUTES.toMillis(1)), - mockitoEq(TimeUnit.MILLISECONDS))).thenReturn(null) - // Creating Futures in Scala backed by a Java executor service resolves to running - // ExecutorService#execute (as opposed to submit) - doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture()) when(rpcEnv.setupEndpoint( mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture())) .thenReturn(driverEndpointRef) - - // Used by the CoarseGrainedSchedulerBackend when making RPC calls. - when(driverEndpointRef.ask[Boolean] - (any(classOf[Any])) - (any())).thenReturn(successFuture) - when(successFuture.failed).thenReturn(Future[Throwable] { - // emulate behavior of the Future.failed method. - throw new NoSuchElementException() - }(ThreadUtils.sameThread)) - } - - test("Basic lifecycle expectations when starting and stopping the scheduler.") { - val scheduler = newSchedulerBackend() - scheduler.start() - assert(executorPodsWatcherArgument.getValue != null) - assert(allocatorRunnable.getValue != null) - scheduler.stop() - verify(executorPodsWatch).close() - } - - test("Static allocation should request executors upon first allocator run.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - allocatorRunnable.getValue.run() - verify(podOperations).create(firstResolvedPod) - verify(podOperations).create(secondResolvedPod) - } - - test("Killing executors deletes the executor pods") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - allocatorRunnable.getValue.run() - scheduler.doKillExecutors(Seq("2")) - requestExecutorRunnable.getAllValues.asScala.last.run() - verify(podOperations).delete(secondResolvedPod) - verify(podOperations, never()).delete(firstResolvedPod) - } - - test("Executors should be requested in batches.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(firstResolvedPod) - verify(podOperations, never()).create(secondResolvedPod) - val registerFirstExecutorMessage = RegisterExecutor( - "1", mock[RpcEndpointRef], "localhost", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - allocatorRunnable.getValue.run() - verify(podOperations).create(secondResolvedPod) - } - - test("Scaled down executors should be cleaned up") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - - // The scheduler backend spins up one executor pod. - requestExecutorRunnable.getValue.run() - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - val resolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - - // Request that there are 0 executors and trigger deletion from driver. - scheduler.doRequestTotalExecutors(0) - requestExecutorRunnable.getAllValues.asScala.last.run() - scheduler.doKillExecutors(Seq("1")) - requestExecutorRunnable.getAllValues.asScala.last.run() - verify(podOperations, times(1)).delete(resolvedPod) - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - - val exitedPod = exitPod(resolvedPod, 0) - executorPodsWatcherArgument.getValue.eventReceived(Action.DELETED, exitedPod) - allocatorRunnable.getValue.run() - - // No more deletion attempts of the executors. - // This is graceful termination and should not be detected as a failure. - verify(podOperations, times(1)).delete(resolvedPod) - verify(driverEndpointRef, times(1)).send( - RemoveExecutor("1", ExecutorExited( - 0, - exitCausedByApp = false, - s"Container in pod ${exitedPod.getMetadata.getName} exited from" + - s" explicit termination request."))) - } - - test("Executors that fail should not be deleted.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - executorPodsWatcherArgument.getValue.eventReceived( - Action.ERROR, exitPod(firstResolvedPod, 1)) - - // A replacement executor should be created but the error pod should persist. - val replacementPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - scheduler.doRequestTotalExecutors(1) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getAllValues.asScala.last.run() - verify(podOperations, never()).delete(firstResolvedPod) - verify(driverEndpointRef).send( - RemoveExecutor("1", ExecutorExited( - 1, - exitCausedByApp = true, - s"Pod ${FIRST_EXECUTOR_POD.getMetadata.getName}'s executor container exited with" + - " exit status code 1."))) - } - - test("Executors disconnected due to unknown reasons are deleted and replaced.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val executorLostReasonCheckMaxAttempts = sparkConf.get( - KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) - - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - 1 to executorLostReasonCheckMaxAttempts foreach { _ => - allocatorRunnable.getValue.run() - verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + when(kubernetesClient.pods()).thenReturn(podOperations) + schedulerBackendUnderTest = new KubernetesClusterSchedulerBackend( + taskScheduler, + rpcEnv, + kubernetesClient, + requestExecutorsService, + eventQueue, + podAllocator, + lifecycleEventHandler, + watchEvents, + pollEvents) { + override def applicationId(): String = TEST_SPARK_APP_ID } - - val recreatedResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).delete(firstResolvedPod) - verify(driverEndpointRef).send( - RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons."))) } - test("Executors that fail to start on the Kubernetes API call rebuild in the next batch.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(firstResolvedPod)) - .thenThrow(new RuntimeException("test")) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - verify(podOperations, times(1)).create(firstResolvedPod) - val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(recreatedResolvedPod) + test("Start all components") { + schedulerBackendUnderTest.start() + verify(podAllocator).setTotalExpectedExecutors(3) + verify(podAllocator).start(TEST_SPARK_APP_ID) + verify(lifecycleEventHandler).start(schedulerBackendUnderTest) + verify(watchEvents).start(TEST_SPARK_APP_ID) + verify(pollEvents).start(TEST_SPARK_APP_ID) } - test("Executors that are initially created but the watch notices them fail are rebuilt" + - " in the next batch.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(FIRST_EXECUTOR_POD)).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - verify(podOperations, times(1)).create(firstResolvedPod) - executorPodsWatcherArgument.getValue.eventReceived(Action.ERROR, firstResolvedPod) - val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(recreatedResolvedPod) + test("Stop all components") { + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods) + when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods) + schedulerBackendUnderTest.stop() + verify(eventQueue).stop() + verify(watchEvents).stop() + verify(pollEvents).stop() + verify(labeledPods).delete() + verify(kubernetesClient).close() } - private def newSchedulerBackend(): KubernetesClusterSchedulerBackend = { - new KubernetesClusterSchedulerBackend( - taskSchedulerImpl, - rpcEnv, - executorBuilder, - kubernetesClient, - allocatorExecutor, - requestExecutorsService) { - - override def applicationId(): String = APP_ID - } + test("Remove executor") { + schedulerBackendUnderTest.start() + schedulerBackendUnderTest.doRemoveExecutor( + "1", ExecutorKilled) + verify(driverEndpointRef).send(RemoveExecutor("1", ExecutorKilled)) } - private def exitPod(basePod: Pod, exitCode: Int): Pod = { - new PodBuilder(basePod) - .editStatus() - .addNewContainerStatus() - .withNewState() - .withNewTerminated() - .withExitCode(exitCode) - .endTerminated() - .endState() - .endContainerStatus() - .endStatus() - .build() + test("Kill executors") { + schedulerBackendUnderTest.start() + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods) + when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods) + when(labeledPods.withLabelIn(SPARK_EXECUTOR_ID_LABEL, "1", "2")).thenReturn(labeledPods) + schedulerBackendUnderTest.doKillExecutors(Seq("1", "2")) + verify(labeledPods, never()).delete() + requestExecutorsService.runNextPendingCommand() + verify(labeledPods).delete() } - private def expectPodCreationWithId(executorId: Int, expectedPod: Pod): Pod = { - val resolvedPod = new PodBuilder(expectedPod) - .editMetadata() - .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) - .endMetadata() - .build() - val resolvedContainer = new ContainerBuilder().build() - when(executorBuilder.buildFromFeatures(Matchers.argThat( - new BaseMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { - override def matches(argument: scala.Any) - : Boolean = { - argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] && - argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] - .roleSpecificConf.executorId == executorId.toString - } - - override def describeTo(description: Description): Unit = {} - }))).thenReturn(SparkPod(resolvedPod, resolvedContainer)) - new PodBuilder(resolvedPod) - .editSpec() - .addToContainers(resolvedContainer) - .endSpec() - .build() + test("Request total executors") { + schedulerBackendUnderTest.start() + schedulerBackendUnderTest.doRequestTotalExecutors(5) + verify(podAllocator).setTotalExpectedExecutors(3) + verify(podAllocator, never()).setTotalExpectedExecutors(5) + requestExecutorsService.runNextPendingCommand() + verify(podAllocator).setTotalExpectedExecutors(5) } + } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index f5270623f8acc..d0b4127065eb7 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -19,21 +19,33 @@ package org.apache.spark.scheduler.cluster.k8s import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.features._ class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val SECRETS_STEP_TYPE = "mount-secrets" + private val ENV_SECRETS_STEP_TYPE = "env-secrets" + private val LOCAL_DIRS_STEP_TYPE = "local-dirs" + private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) + private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) + private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) private val builderUnderTest = new KubernetesExecutorBuilder( _ => basicFeatureStep, - _ => mountSecretsStep) + _ => mountSecretsStep, + _ => envSecretsStep, + _ => localDirsStep, + _ => mountVolumesStep) test("Basic steps are consistently applied.") { val conf = KubernetesConf( @@ -45,8 +57,12 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, - Map.empty) - validateStepTypesApplied(builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE) + Map.empty, + Map.empty, + Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) } test("Apply secrets step if secrets are present.") { @@ -59,11 +75,42 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map("secret" -> "secretMountPath"), - Map.empty) + Map("secret-name" -> "secret-key"), + Map.empty, + Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + SECRETS_STEP_TYPE, + ENV_SECRETS_STEP_TYPE) + } + + test("Apply volumes step if mounts are present.") { + val volumeSpec = KubernetesVolumeSpec( + "volume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/checkpoint")) + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + volumeSpec :: Nil, + Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, - SECRETS_STEP_TYPE) + LOCAL_DIRS_STEP_TYPE, + MOUNT_VOLUMES_STEP_TYPE) } private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile new file mode 100644 index 0000000000000..72bb9620b45de --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +ARG base_img +FROM $base_img +WORKDIR / +RUN mkdir ${SPARK_HOME}/python +COPY python/lib ${SPARK_HOME}/python/lib +# TODO: Investigate running both pip and pip3 via virtualenvs +RUN apk add --no-cache python && \ + apk add --no-cache python3 && \ + python -m ensurepip && \ + python3 -m ensurepip && \ + # We remove ensurepip since it adds no functionality since pip is + # installed on the image and it just takes up 1.6MB on the image + rm -r /usr/lib/python*/ensurepip && \ + pip install --upgrade pip setuptools && \ + # You may install with python3 packages by using pip3.6 + # Removed the .cache to save space + rm -r /root/.cache + +ENV PYTHONPATH ${SPARK_HOME}/python/lib/pyspark.zip:${SPARK_HOME}/python/lib/py4j-*.zip + +WORKDIR /opt/spark/work-dir +ENTRYPOINT [ "/opt/entrypoint.sh" ] diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 3e166116aa3fd..8bdb0f7a10795 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -37,20 +37,46 @@ if [ -z "$uidentry" ] ; then fi SPARK_K8S_CMD="$1" -if [ -z "$SPARK_K8S_CMD" ]; then - echo "No command to execute has been provided." 1>&2 - exit 1 -fi -shift 1 +case "$SPARK_K8S_CMD" in + driver | driver-py | executor) + shift 1 + ;; + "") + ;; + *) + echo "Non-spark-on-k8s command provided, proceeding in pass-through mode..." + exec /sbin/tini -s -- "$@" + ;; +esac SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt -readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt -if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then - SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" +readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt + +if [ -n "$SPARK_EXTRA_CLASSPATH" ]; then + SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_EXTRA_CLASSPATH" fi -if [ -n "$SPARK_MOUNTED_FILES_DIR" ]; then - cp -R "$SPARK_MOUNTED_FILES_DIR/." . + +if [ -n "$PYSPARK_FILES" ]; then + PYTHONPATH="$PYTHONPATH:$PYSPARK_FILES" +fi + +PYSPARK_ARGS="" +if [ -n "$PYSPARK_APP_ARGS" ]; then + PYSPARK_ARGS="$PYSPARK_APP_ARGS" +fi + + +if [ "$PYSPARK_MAJOR_PYTHON_VERSION" == "2" ]; then + pyv="$(python -V 2>&1)" + export PYTHON_VERSION="${pyv:7}" + export PYSPARK_PYTHON="python" + export PYSPARK_DRIVER_PYTHON="python" +elif [ "$PYSPARK_MAJOR_PYTHON_VERSION" == "3" ]; then + pyv3="$(python3 -V 2>&1)" + export PYTHON_VERSION="${pyv3:7}" + export PYSPARK_PYTHON="python3" + export PYSPARK_DRIVER_PYTHON="python3" fi case "$SPARK_K8S_CMD" in @@ -62,11 +88,18 @@ case "$SPARK_K8S_CMD" in "$@" ) ;; - + driver-py) + CMD=( + "$SPARK_HOME/bin/spark-submit" + --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS" + --deploy-mode client + "$@" $PYSPARK_PRIMARY $PYSPARK_ARGS + ) + ;; executor) CMD=( ${JAVA_HOME}/bin/java - "${SPARK_JAVA_OPTS[@]}" + "${SPARK_EXECUTOR_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md new file mode 100644 index 0000000000000..b3863e6b7d1af --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -0,0 +1,52 @@ +--- +layout: global +title: Spark on Kubernetes Integration Tests +--- + +# Running the Kubernetes Integration Tests + +Note that the integration test framework is currently being heavily revised and +is subject to change. Note that currently the integration tests only run with Java 8. + +The simplest way to run the integration tests is to install and run Minikube, then run the following: + + dev/dev-run-integration-tests.sh + +The minimum tested version of Minikube is 0.23.0. The kube-dns addon must be enabled. Minikube should +run with a minimum of 3 CPUs and 4G of memory: + + minikube start --cpus 3 --memory 4096 + +You can download Minikube [here](https://github.com/kubernetes/minikube/releases). + +# Integration test customization + +Configuration of the integration test runtime is done through passing different arguments to the test script. The main useful options are outlined below. + +## Re-using Docker Images + +By default, the test framework will build new Docker images on every test execution. A unique image tag is generated, +and it is written to file at `target/imageTag.txt`. To reuse the images built in a previous run, or to use a Docker image tag +that you have built by other means already, pass the tag to the test script: + + dev/dev-run-integration-tests.sh --image-tag + +where if you still want to use images that were built before by the test framework: + + dev/dev-run-integration-tests.sh --image-tag $(cat target/imageTag.txt) + +## Spark Distribution Under Test + +The Spark code to test is handed to the integration test system via a tarball. Here is the option that is used to specify the tarball: + +* `--spark-tgz ` - set `` to point to a tarball containing the Spark distribution to test. + +TODO: Don't require the packaging of the built Spark artifacts into this tarball, just read them out of the current tree. + +## Customizing the Namespace and Service Account + +* `--namespace ` - set `` to the namespace in which the tests should be run. +* `--service-account ` - set `` to the name of the Kubernetes service account to +use in the namespace specified by the `--namespace`. The service account is expected to have permissions to get, list, watch, +and create pods. For clusters with RBAC turned on, it's important that the right permissions are granted to the service account +in the namespace through an appropriate role and role binding. A reference RBAC configuration is provided in `dev/spark-rbac.yaml`. diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh new file mode 100755 index 0000000000000..3acd0f5cd3349 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -0,0 +1,113 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +TEST_ROOT_DIR=$(git rev-parse --show-toplevel)/resource-managers/kubernetes/integration-tests + +cd "${TEST_ROOT_DIR}" + +DEPLOY_MODE="minikube" +IMAGE_REPO="docker.io/kubespark" +SPARK_TGZ="N/A" +IMAGE_TAG="N/A" +SPARK_MASTER= +NAMESPACE= +SERVICE_ACCOUNT= +INCLUDE_TAGS= +EXCLUDE_TAGS= + +# Parse arguments +while (( "$#" )); do + case $1 in + --image-repo) + IMAGE_REPO="$2" + shift + ;; + --image-tag) + IMAGE_TAG="$2" + shift + ;; + --deploy-mode) + DEPLOY_MODE="$2" + shift + ;; + --spark-tgz) + SPARK_TGZ="$2" + shift + ;; + --spark-master) + SPARK_MASTER="$2" + shift + ;; + --namespace) + NAMESPACE="$2" + shift + ;; + --service-account) + SERVICE_ACCOUNT="$2" + shift + ;; + --include-tags) + INCLUDE_TAGS="$2" + shift + ;; + --exclude-tags) + EXCLUDE_TAGS="$2" + shift + ;; + *) + break + ;; + esac + shift +done + +cd $TEST_ROOT_DIR + +properties=( + -Dspark.kubernetes.test.sparkTgz=$SPARK_TGZ \ + -Dspark.kubernetes.test.imageTag=$IMAGE_TAG \ + -Dspark.kubernetes.test.imageRepo=$IMAGE_REPO \ + -Dspark.kubernetes.test.deployMode=$DEPLOY_MODE +) + +if [ -n $NAMESPACE ]; +then + properties=( ${properties[@]} -Dspark.kubernetes.test.namespace=$NAMESPACE ) +fi + +if [ -n $SERVICE_ACCOUNT ]; +then + properties=( ${properties[@]} -Dspark.kubernetes.test.serviceAccountName=$SERVICE_ACCOUNT ) +fi + +if [ -n $SPARK_MASTER ]; +then + properties=( ${properties[@]} -Dspark.kubernetes.test.master=$SPARK_MASTER ) +fi + +if [ -n $EXCLUDE_TAGS ]; +then + properties=( ${properties[@]} -Dtest.exclude.tags=$EXCLUDE_TAGS ) +fi + +if [ -n $INCLUDE_TAGS ]; +then + properties=( ${properties[@]} -Dtest.include.tags=$INCLUDE_TAGS ) +fi + +../../../build/mvn integration-test ${properties[@]} diff --git a/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml b/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml new file mode 100644 index 0000000000000..a4c242f2f2645 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +apiVersion: v1 +kind: Namespace +metadata: + name: spark +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: spark-sa + namespace: spark +--- +apiVersion: rbac.authorization.k8s.io/v1beta1 +kind: ClusterRole +metadata: + name: spark-role +rules: +- apiGroups: + - "" + resources: + - "pods" + verbs: + - "*" +--- +apiVersion: rbac.authorization.k8s.io/v1beta1 +kind: ClusterRoleBinding +metadata: + name: spark-role-binding +subjects: +- kind: ServiceAccount + name: spark-sa + namespace: spark +roleRef: + kind: ClusterRole + name: spark-role + apiGroup: rbac.authorization.k8s.io \ No newline at end of file diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml new file mode 100644 index 0000000000000..29334cc6d891d --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -0,0 +1,171 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.4.0-SNAPSHOT + ../../../pom.xml + + + spark-kubernetes-integration-tests_2.11 + spark-kubernetes-integration-tests + + 1.3.0 + 1.4.0 + + 3.0.0 + 3.2.2 + 1.0 + kubernetes-integration-tests + ${project.build.directory}/spark-dist-unpacked + N/A + ${project.build.directory}/imageTag.txt + minikube + docker.io/kubespark + + + + jar + Spark Project Kubernetes Integration Tests + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + io.fabric8 + kubernetes-client + ${kubernetes-client.version} + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + + + + + + + org.codehaus.mojo + exec-maven-plugin + ${exec-maven-plugin.version} + + + setup-integration-test-env + pre-integration-test + + exec + + + scripts/setup-integration-test-env.sh + + --unpacked-spark-tgz + ${spark.kubernetes.test.unpackSparkDir} + + --image-repo + ${spark.kubernetes.test.imageRepo} + + --image-tag + ${spark.kubernetes.test.imageTag} + + --image-tag-output-file + ${spark.kubernetes.test.imageTagFile} + + --deploy-mode + ${spark.kubernetes.test.deployMode} + + --spark-tgz + ${spark.kubernetes.test.sparkTgz} + + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + + + + org.scalatest + scalatest-maven-plugin + ${scalatest-maven-plugin.version} + + ${project.build.directory}/surefire-reports + . + SparkTestSuite.txt + -ea -Xmx3g -XX:ReservedCodeCacheSize=512m ${extraScalaTestArgs} + + + file:src/test/resources/log4j.properties + true + ${spark.kubernetes.test.imageTagFile} + ${spark.kubernetes.test.unpackSparkDir} + ${spark.kubernetes.test.imageRepo} + ${spark.kubernetes.test.deployMode} + ${spark.kubernetes.test.master} + ${spark.kubernetes.test.namespace} + ${spark.kubernetes.test.serviceAccountName} + + ${test.exclude.tags} + ${test.include.tags} + + + + test + + test + + + + (?<!Suite) + + + + integration-test + integration-test + + test + + + + + + + + + diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh new file mode 100755 index 0000000000000..ccfb8e767c529 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +TEST_ROOT_DIR=$(git rev-parse --show-toplevel) +UNPACKED_SPARK_TGZ="$TEST_ROOT_DIR/target/spark-dist-unpacked" +IMAGE_TAG_OUTPUT_FILE="$TEST_ROOT_DIR/target/image-tag.txt" +DEPLOY_MODE="minikube" +IMAGE_REPO="docker.io/kubespark" +IMAGE_TAG="N/A" +SPARK_TGZ="N/A" + +# Parse arguments +while (( "$#" )); do + case $1 in + --unpacked-spark-tgz) + UNPACKED_SPARK_TGZ="$2" + shift + ;; + --image-repo) + IMAGE_REPO="$2" + shift + ;; + --image-tag) + IMAGE_TAG="$2" + shift + ;; + --image-tag-output-file) + IMAGE_TAG_OUTPUT_FILE="$2" + shift + ;; + --deploy-mode) + DEPLOY_MODE="$2" + shift + ;; + --spark-tgz) + SPARK_TGZ="$2" + shift + ;; + *) + break + ;; + esac + shift +done + +if [[ $SPARK_TGZ == "N/A" ]]; +then + echo "Must specify a Spark tarball to build Docker images against with --spark-tgz." && exit 1; +fi + +rm -rf $UNPACKED_SPARK_TGZ +mkdir -p $UNPACKED_SPARK_TGZ +tar -xzvf $SPARK_TGZ --strip-components=1 -C $UNPACKED_SPARK_TGZ; + +if [[ $IMAGE_TAG == "N/A" ]]; +then + IMAGE_TAG=$(uuidgen); + cd $UNPACKED_SPARK_TGZ + if [[ $DEPLOY_MODE == cloud ]] ; + then + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build + if [[ $IMAGE_REPO == gcr.io* ]] ; + then + gcloud docker -- push $IMAGE_REPO/spark:$IMAGE_TAG + else + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG push + fi + else + # -m option for minikube. + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG build + fi + cd - +fi + +rm -f $IMAGE_TAG_OUTPUT_FILE +echo -n $IMAGE_TAG > $IMAGE_TAG_OUTPUT_FILE diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties b/resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..866126bc3c1c2 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/integration-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/integration-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from a few verbose libraries. +log4j.logger.com.sun.jersey=WARN +log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.mortbay=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala new file mode 100644 index 0000000000000..774c3936b877c --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.io.File +import java.nio.file.{Path, Paths} +import java.util.UUID +import java.util.regex.Pattern + +import com.google.common.io.PatternFilenameFilter +import io.fabric8.kubernetes.api.model.Pod +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.concurrent.{Eventually, PatienceConfiguration} +import org.scalatest.time.{Minutes, Seconds, Span} +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory} +import org.apache.spark.deploy.k8s.integrationtest.config._ +import org.apache.spark.launcher.SparkLauncher + +private[spark] class KubernetesSuite extends SparkFunSuite + with BeforeAndAfterAll with BeforeAndAfter { + + import KubernetesSuite._ + + private var testBackend: IntegrationTestBackend = _ + private var sparkHomeDir: Path = _ + private var kubernetesTestComponents: KubernetesTestComponents = _ + private var sparkAppConf: SparkAppConf = _ + private var image: String = _ + private var pyImage: String = _ + private var containerLocalSparkDistroExamplesJar: String = _ + private var appLocator: String = _ + private var driverPodName: String = _ + + override def beforeAll(): Unit = { + // The scalatest-maven-plugin gives system properties that are referenced but not set null + // values. We need to remove the null-value properties before initializing the test backend. + val nullValueProperties = System.getProperties.asScala + .filter(entry => entry._2.equals("null")) + .map(entry => entry._1.toString) + nullValueProperties.foreach { key => + System.clearProperty(key) + } + + val sparkDirProp = System.getProperty("spark.kubernetes.test.unpackSparkDir") + require(sparkDirProp != null, "Spark home directory must be provided in system properties.") + sparkHomeDir = Paths.get(sparkDirProp) + require(sparkHomeDir.toFile.isDirectory, + s"No directory found for spark home specified at $sparkHomeDir.") + val imageTag = getTestImageTag + val imageRepo = getTestImageRepo + image = s"$imageRepo/spark:$imageTag" + pyImage = s"$imageRepo/spark-py:$imageTag" + + val sparkDistroExamplesJarFile: File = sparkHomeDir.resolve(Paths.get("examples", "jars")) + .toFile + .listFiles(new PatternFilenameFilter(Pattern.compile("^spark-examples_.*\\.jar$")))(0) + containerLocalSparkDistroExamplesJar = s"local:///opt/spark/examples/jars/" + + s"${sparkDistroExamplesJarFile.getName}" + testBackend = IntegrationTestBackendFactory.getTestBackend + testBackend.initialize() + kubernetesTestComponents = new KubernetesTestComponents(testBackend.getKubernetesClient) + } + + override def afterAll(): Unit = { + testBackend.cleanUp() + } + + before { + appLocator = UUID.randomUUID().toString.replaceAll("-", "") + driverPodName = "spark-test-app-" + UUID.randomUUID().toString.replaceAll("-", "") + sparkAppConf = kubernetesTestComponents.newSparkAppConf() + .set("spark.kubernetes.container.image", image) + .set("spark.kubernetes.driver.pod.name", driverPodName) + .set("spark.kubernetes.driver.label.spark-app-locator", appLocator) + .set("spark.kubernetes.executor.label.spark-app-locator", appLocator) + if (!kubernetesTestComponents.hasUserSpecifiedNamespace) { + kubernetesTestComponents.createNamespace() + } + } + + after { + if (!kubernetesTestComponents.hasUserSpecifiedNamespace) { + kubernetesTestComponents.deleteNamespace() + } + deleteDriverPod() + } + + test("Run SparkPi with no resources") { + runSparkPiAndVerifyCompletion() + } + + test("Run SparkPi with a very long application name.") { + sparkAppConf.set("spark.app.name", "long" * 40) + runSparkPiAndVerifyCompletion() + } + + test("Use SparkLauncher.NO_RESOURCE") { + sparkAppConf.setJars(Seq(containerLocalSparkDistroExamplesJar)) + runSparkPiAndVerifyCompletion( + appResource = SparkLauncher.NO_RESOURCE) + } + + test("Run SparkPi with a master URL without a scheme.") { + val url = kubernetesTestComponents.kubernetesClient.getMasterUrl + val k8sMasterUrl = if (url.getPort < 0) { + s"k8s://${url.getHost}" + } else { + s"k8s://${url.getHost}:${url.getPort}" + } + sparkAppConf.set("spark.master", k8sMasterUrl) + runSparkPiAndVerifyCompletion() + } + + test("Run SparkPi with an argument.") { + runSparkPiAndVerifyCompletion(appArgs = Array("5")) + } + + test("Run SparkPi with custom labels, annotations, and environment variables.") { + sparkAppConf + .set("spark.kubernetes.driver.label.label1", "label1-value") + .set("spark.kubernetes.driver.label.label2", "label2-value") + .set("spark.kubernetes.driver.annotation.annotation1", "annotation1-value") + .set("spark.kubernetes.driver.annotation.annotation2", "annotation2-value") + .set("spark.kubernetes.driverEnv.ENV1", "VALUE1") + .set("spark.kubernetes.driverEnv.ENV2", "VALUE2") + .set("spark.kubernetes.executor.label.label1", "label1-value") + .set("spark.kubernetes.executor.label.label2", "label2-value") + .set("spark.kubernetes.executor.annotation.annotation1", "annotation1-value") + .set("spark.kubernetes.executor.annotation.annotation2", "annotation2-value") + .set("spark.executorEnv.ENV1", "VALUE1") + .set("spark.executorEnv.ENV2", "VALUE2") + + runSparkPiAndVerifyCompletion( + driverPodChecker = (driverPod: Pod) => { + doBasicDriverPodCheck(driverPod) + checkCustomSettings(driverPod) + }, + executorPodChecker = (executorPod: Pod) => { + doBasicExecutorPodCheck(executorPod) + checkCustomSettings(executorPod) + }) + } + + test("Run extraJVMOptions check on driver") { + sparkAppConf + .set("spark.driver.extraJavaOptions", "-Dspark.test.foo=spark.test.bar") + runSparkJVMCheckAndVerifyCompletion( + expectedJVMValue = Seq("(spark.test.foo,spark.test.bar)")) + } + + test("Run SparkRemoteFileTest using a remote data file") { + sparkAppConf + .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) + runSparkRemoteCheckAndVerifyCompletion( + appArgs = Array(REMOTE_PAGE_RANK_FILE_NAME)) + } + + test("Run PySpark on simple pi.py example") { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_PI, + mainClass = "", + expectedLogOnCompletion = Seq("Pi is roughly 3"), + appArgs = Array("5"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false) + } + + test("Run PySpark with Python2 to test a pyfiles example") { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + .set("spark.kubernetes.pyspark.pythonversion", "2") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_FILES, + mainClass = "", + expectedLogOnCompletion = Seq( + "Python runtime version check is: True", + "Python environment version check is: True"), + appArgs = Array("python"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } + + test("Run PySpark with Python3 to test a pyfiles example") { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + .set("spark.kubernetes.pyspark.pythonversion", "3") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_FILES, + mainClass = "", + expectedLogOnCompletion = Seq( + "Python runtime version check is: True", + "Python environment version check is: True"), + appArgs = Array("python3"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } + + private def runSparkPiAndVerifyCompletion( + appResource: String = containerLocalSparkDistroExamplesJar, + driverPodChecker: Pod => Unit = doBasicDriverPodCheck, + executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, + appArgs: Array[String] = Array.empty[String], + appLocator: String = appLocator, + isJVM: Boolean = true ): Unit = { + runSparkApplicationAndVerifyCompletion( + appResource, + SPARK_PI_MAIN_CLASS, + Seq("Pi is roughly 3"), + appArgs, + driverPodChecker, + executorPodChecker, + appLocator, + isJVM) + } + + private def runSparkRemoteCheckAndVerifyCompletion( + appResource: String = containerLocalSparkDistroExamplesJar, + driverPodChecker: Pod => Unit = doBasicDriverPodCheck, + executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, + appArgs: Array[String], + appLocator: String = appLocator): Unit = { + runSparkApplicationAndVerifyCompletion( + appResource, + SPARK_REMOTE_MAIN_CLASS, + Seq(s"Mounting of ${appArgs.head} was true"), + appArgs, + driverPodChecker, + executorPodChecker, + appLocator, + true) + } + + private def runSparkJVMCheckAndVerifyCompletion( + appResource: String = containerLocalSparkDistroExamplesJar, + mainClass: String = SPARK_DRIVER_MAIN_CLASS, + driverPodChecker: Pod => Unit = doBasicDriverPodCheck, + appArgs: Array[String] = Array("5"), + expectedJVMValue: Seq[String]): Unit = { + val appArguments = SparkAppArguments( + mainAppResource = appResource, + mainClass = mainClass, + appArgs = appArgs) + SparkAppLauncher.launch( + appArguments, + sparkAppConf, + TIMEOUT.value.toSeconds.toInt, + sparkHomeDir, + true) + + val driverPod = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "driver") + .list() + .getItems + .get(0) + doBasicDriverPodCheck(driverPod) + + Eventually.eventually(TIMEOUT, INTERVAL) { + expectedJVMValue.foreach { e => + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPod.getMetadata.getName) + .getLog + .contains(e), "The application did not complete.") + } + } + } + + private def runSparkApplicationAndVerifyCompletion( + appResource: String, + mainClass: String, + expectedLogOnCompletion: Seq[String], + appArgs: Array[String], + driverPodChecker: Pod => Unit, + executorPodChecker: Pod => Unit, + appLocator: String, + isJVM: Boolean, + pyFiles: Option[String] = None): Unit = { + val appArguments = SparkAppArguments( + mainAppResource = appResource, + mainClass = mainClass, + appArgs = appArgs) + SparkAppLauncher.launch( + appArguments, + sparkAppConf, + TIMEOUT.value.toSeconds.toInt, + sparkHomeDir, + isJVM, + pyFiles) + + val driverPod = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "driver") + .list() + .getItems + .get(0) + driverPodChecker(driverPod) + + val executorPods = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "executor") + .list() + .getItems + executorPods.asScala.foreach { pod => + executorPodChecker(pod) + } + + Eventually.eventually(TIMEOUT, INTERVAL) { + expectedLogOnCompletion.foreach { e => + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPod.getMetadata.getName) + .getLog + .contains(e), "The application did not complete.") + } + } + } + + private def doBasicDriverPodCheck(driverPod: Pod): Unit = { + assert(driverPod.getMetadata.getName === driverPodName) + assert(driverPod.getSpec.getContainers.get(0).getImage === image) + assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + } + + private def doBasicDriverPyPodCheck(driverPod: Pod): Unit = { + assert(driverPod.getMetadata.getName === driverPodName) + assert(driverPod.getSpec.getContainers.get(0).getImage === pyImage) + assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + } + + private def doBasicExecutorPodCheck(executorPod: Pod): Unit = { + assert(executorPod.getSpec.getContainers.get(0).getImage === image) + assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + } + + private def doBasicExecutorPyPodCheck(executorPod: Pod): Unit = { + assert(executorPod.getSpec.getContainers.get(0).getImage === pyImage) + assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + } + + private def checkCustomSettings(pod: Pod): Unit = { + assert(pod.getMetadata.getLabels.get("label1") === "label1-value") + assert(pod.getMetadata.getLabels.get("label2") === "label2-value") + assert(pod.getMetadata.getAnnotations.get("annotation1") === "annotation1-value") + assert(pod.getMetadata.getAnnotations.get("annotation2") === "annotation2-value") + + val container = pod.getSpec.getContainers.get(0) + val envVars = container + .getEnv + .asScala + .map { env => + (env.getName, env.getValue) + } + .toMap + assert(envVars("ENV1") === "VALUE1") + assert(envVars("ENV2") === "VALUE2") + } + + private def deleteDriverPod(): Unit = { + kubernetesTestComponents.kubernetesClient.pods().withName(driverPodName).delete() + Eventually.eventually(TIMEOUT, INTERVAL) { + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPodName) + .get() == null) + } + } +} + +private[spark] object KubernetesSuite { + + val TIMEOUT = PatienceConfiguration.Timeout(Span(2, Minutes)) + val INTERVAL = PatienceConfiguration.Interval(Span(2, Seconds)) + val SPARK_PI_MAIN_CLASS: String = "org.apache.spark.examples.SparkPi" + val SPARK_REMOTE_MAIN_CLASS: String = "org.apache.spark.examples.SparkRemoteFileTest" + val SPARK_DRIVER_MAIN_CLASS: String = "org.apache.spark.examples.DriverSubmissionTest" + val SPARK_PAGE_RANK_MAIN_CLASS: String = "org.apache.spark.examples.SparkPageRank" + val CONTAINER_LOCAL_PYSPARK: String = "local:///opt/spark/examples/src/main/python/" + val PYSPARK_PI: String = CONTAINER_LOCAL_PYSPARK + "pi.py" + val PYSPARK_FILES: String = CONTAINER_LOCAL_PYSPARK + "pyfiles.py" + val PYSPARK_CONTAINER_TESTS: String = CONTAINER_LOCAL_PYSPARK + "py_container_checks.py" + + val TEST_SECRET_NAME_PREFIX = "test-secret-" + val TEST_SECRET_KEY = "test-key" + val TEST_SECRET_VALUE = "test-data" + val TEST_SECRET_MOUNT_PATH = "/etc/secrets" + + val REMOTE_PAGE_RANK_DATA_FILE = + "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt" + val REMOTE_PAGE_RANK_FILE_NAME = "pagerank_data.txt" + + case object ShuffleNotReadyException extends Exception +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala new file mode 100644 index 0000000000000..a9b49a8e5a610 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.nio.file.{Path, Paths} +import java.util.UUID + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import io.fabric8.kubernetes.client.DefaultKubernetesClient +import org.scalatest.concurrent.Eventually + +import org.apache.spark.internal.Logging + +private[spark] class KubernetesTestComponents(defaultClient: DefaultKubernetesClient) { + + val namespaceOption = Option(System.getProperty("spark.kubernetes.test.namespace")) + val hasUserSpecifiedNamespace = namespaceOption.isDefined + val namespace = namespaceOption.getOrElse(UUID.randomUUID().toString.replaceAll("-", "")) + private val serviceAccountName = + Option(System.getProperty("spark.kubernetes.test.serviceAccountName")) + .getOrElse("default") + val kubernetesClient = defaultClient.inNamespace(namespace) + val clientConfig = kubernetesClient.getConfiguration + + def createNamespace(): Unit = { + defaultClient.namespaces.createNew() + .withNewMetadata() + .withName(namespace) + .endMetadata() + .done() + } + + def deleteNamespace(): Unit = { + defaultClient.namespaces.withName(namespace).delete() + Eventually.eventually(KubernetesSuite.TIMEOUT, KubernetesSuite.INTERVAL) { + val namespaceList = defaultClient + .namespaces() + .list() + .getItems + .asScala + require(!namespaceList.exists(_.getMetadata.getName == namespace)) + } + } + + def newSparkAppConf(): SparkAppConf = { + new SparkAppConf() + .set("spark.master", s"k8s://${kubernetesClient.getMasterUrl}") + .set("spark.kubernetes.namespace", namespace) + .set("spark.executor.memory", "500m") + .set("spark.executor.cores", "1") + .set("spark.executors.instances", "1") + .set("spark.app.name", "spark-test-app") + .set("spark.ui.enabled", "true") + .set("spark.testing", "false") + .set("spark.kubernetes.submission.waitAppCompletion", "false") + .set("spark.kubernetes.authenticate.driver.serviceAccountName", serviceAccountName) + } +} + +private[spark] class SparkAppConf { + + private val map = mutable.Map[String, String]() + + def set(key: String, value: String): SparkAppConf = { + map.put(key, value) + this + } + + def get(key: String): String = map.getOrElse(key, "") + + def setJars(jars: Seq[String]): Unit = set("spark.jars", jars.mkString(",")) + + override def toString: String = map.toString + + def toStringArray: Iterable[String] = map.toList.flatMap(t => List("--conf", s"${t._1}=${t._2}")) +} + +private[spark] case class SparkAppArguments( + mainAppResource: String, + mainClass: String, + appArgs: Array[String]) + +private[spark] object SparkAppLauncher extends Logging { + def launch( + appArguments: SparkAppArguments, + appConf: SparkAppConf, + timeoutSecs: Int, + sparkHomeDir: Path, + isJVM: Boolean, + pyFiles: Option[String] = None): Unit = { + val sparkSubmitExecutable = sparkHomeDir.resolve(Paths.get("bin", "spark-submit")) + logInfo(s"Launching a spark app with arguments $appArguments and conf $appConf") + val preCommandLine = if (isJVM) { + mutable.ArrayBuffer(sparkSubmitExecutable.toFile.getAbsolutePath, + "--deploy-mode", "cluster", + "--class", appArguments.mainClass, + "--master", appConf.get("spark.master")) + } else { + mutable.ArrayBuffer(sparkSubmitExecutable.toFile.getAbsolutePath, + "--deploy-mode", "cluster", + "--master", appConf.get("spark.master")) + } + val commandLine = + pyFiles.map(s => preCommandLine ++ Array("--py-files", s)).getOrElse(preCommandLine) ++ + appConf.toStringArray :+ appArguments.mainAppResource + + if (appArguments.appArgs.nonEmpty) { + commandLine += appArguments.appArgs.mkString(" ") + } + logInfo(s"Launching a spark app with command line: ${commandLine.mkString(" ")}") + ProcessUtils.executeProcess(commandLine.toArray, timeoutSecs) + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala new file mode 100644 index 0000000000000..d8f3a6cec05c3 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.util.concurrent.TimeUnit + +import scala.collection.mutable.ArrayBuffer +import scala.io.Source + +import org.apache.spark.internal.Logging + +object ProcessUtils extends Logging { + /** + * executeProcess is used to run a command and return the output if it + * completes within timeout seconds. + */ + def executeProcess(fullCommand: Array[String], timeout: Long): Seq[String] = { + val pb = new ProcessBuilder().command(fullCommand: _*) + pb.redirectErrorStream(true) + val proc = pb.start() + val outputLines = new ArrayBuffer[String] + Utils.tryWithResource(proc.getInputStream)( + Source.fromInputStream(_, "UTF-8").getLines().foreach { line => + logInfo(line) + outputLines += line + }) + assert(proc.waitFor(timeout, TimeUnit.SECONDS), + s"Timed out while executing ${fullCommand.mkString(" ")}") + assert(proc.exitValue == 0, s"Failed to execute ${fullCommand.mkString(" ")}") + outputLines + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala new file mode 100644 index 0000000000000..f1fd6dc19ce54 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.util.concurrent.TimeUnit + +import com.google.common.util.concurrent.SettableFuture +import io.fabric8.kubernetes.api.model.HasMetadata +import io.fabric8.kubernetes.client.{KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import io.fabric8.kubernetes.client.internal.readiness.Readiness + +private[spark] class SparkReadinessWatcher[T <: HasMetadata] extends Watcher[T] { + + private val signal = SettableFuture.create[Boolean] + + override def eventReceived(action: Action, resource: T): Unit = { + if ((action == Action.MODIFIED || action == Action.ADDED) && + Readiness.isReady(resource)) { + signal.set(true) + } + } + + override def onClose(cause: KubernetesClientException): Unit = {} + + def waitUntilReady(): Boolean = signal.get(60, TimeUnit.SECONDS) +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala new file mode 100644 index 0000000000000..663f8b6523ac8 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.io.Closeable +import java.net.URI + +import org.apache.spark.internal.Logging + +object Utils extends Logging { + + def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { + val resource = createResource + try f.apply(resource) finally resource.close() + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala new file mode 100644 index 0000000000000..284712c6d250e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s.integrationtest.backend + +import io.fabric8.kubernetes.client.DefaultKubernetesClient + +import org.apache.spark.deploy.k8s.integrationtest.backend.minikube.MinikubeTestBackend + +private[spark] trait IntegrationTestBackend { + def initialize(): Unit + def getKubernetesClient: DefaultKubernetesClient + def cleanUp(): Unit = {} +} + +private[spark] object IntegrationTestBackendFactory { + val deployModeConfigKey = "spark.kubernetes.test.deployMode" + + def getTestBackend: IntegrationTestBackend = { + val deployMode = Option(System.getProperty(deployModeConfigKey)) + .getOrElse("minikube") + if (deployMode == "minikube") { + MinikubeTestBackend + } else { + throw new IllegalArgumentException( + "Invalid " + deployModeConfigKey + ": " + deployMode) + } + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala new file mode 100644 index 0000000000000..6494cbc18f33e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest.backend.minikube + +import java.io.File +import java.nio.file.Paths + +import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient} + +import org.apache.spark.deploy.k8s.integrationtest.ProcessUtils +import org.apache.spark.internal.Logging + +// TODO support windows +private[spark] object Minikube extends Logging { + + private val MINIKUBE_STARTUP_TIMEOUT_SECONDS = 60 + + def getMinikubeIp: String = { + val outputs = executeMinikube("ip") + .filter(_.matches("^\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}$")) + assert(outputs.size == 1, "Unexpected amount of output from minikube ip") + outputs.head + } + + def getMinikubeStatus: MinikubeStatus.Value = { + val statusString = executeMinikube("status") + .filter(line => line.contains("minikubeVM: ") || line.contains("minikube:")) + .head + .replaceFirst("minikubeVM: ", "") + .replaceFirst("minikube: ", "") + MinikubeStatus.unapply(statusString) + .getOrElse(throw new IllegalStateException(s"Unknown status $statusString")) + } + + def getKubernetesClient: DefaultKubernetesClient = { + val kubernetesMaster = s"https://${getMinikubeIp}:8443" + val userHome = System.getProperty("user.home") + val kubernetesConf = new ConfigBuilder() + .withApiVersion("v1") + .withMasterUrl(kubernetesMaster) + .withCaCertFile(Paths.get(userHome, ".minikube", "ca.crt").toFile.getAbsolutePath) + .withClientCertFile(Paths.get(userHome, ".minikube", "apiserver.crt").toFile.getAbsolutePath) + .withClientKeyFile(Paths.get(userHome, ".minikube", "apiserver.key").toFile.getAbsolutePath) + .build() + new DefaultKubernetesClient(kubernetesConf) + } + + private def executeMinikube(action: String, args: String*): Seq[String] = { + ProcessUtils.executeProcess( + Array("bash", "-c", s"minikube $action") ++ args, MINIKUBE_STARTUP_TIMEOUT_SECONDS) + } +} + +private[spark] object MinikubeStatus extends Enumeration { + + // The following states are listed according to + // https://github.com/docker/machine/blob/master/libmachine/state/state.go. + val STARTING = status("Starting") + val RUNNING = status("Running") + val PAUSED = status("Paused") + val STOPPING = status("Stopping") + val STOPPED = status("Stopped") + val ERROR = status("Error") + val TIMEOUT = status("Timeout") + val SAVED = status("Saved") + val NONE = status("") + + def status(value: String): Value = new Val(nextId, value) + def unapply(s: String): Option[Value] = values.find(s == _.toString) +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala new file mode 100644 index 0000000000000..cb9324179d70e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest.backend.minikube + +import io.fabric8.kubernetes.client.DefaultKubernetesClient + +import org.apache.spark.deploy.k8s.integrationtest.backend.IntegrationTestBackend + +private[spark] object MinikubeTestBackend extends IntegrationTestBackend { + + private var defaultClient: DefaultKubernetesClient = _ + + override def initialize(): Unit = { + val minikubeStatus = Minikube.getMinikubeStatus + require(minikubeStatus == MinikubeStatus.RUNNING, + s"Minikube must be running to use the Minikube backend for integration tests." + + s" Current status is: $minikubeStatus.") + defaultClient = Minikube.getKubernetesClient + } + + override def cleanUp(): Unit = { + super.cleanUp() + } + + override def getKubernetesClient: DefaultKubernetesClient = { + defaultClient + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala new file mode 100644 index 0000000000000..a81ef455c6766 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.io.File + +import com.google.common.base.Charsets +import com.google.common.io.Files + +package object config { + def getTestImageTag: String = { + val imageTagFileProp = System.getProperty("spark.kubernetes.test.imageTagFile") + require(imageTagFileProp != null, "Image tag file must be provided in system properties.") + val imageTagFile = new File(imageTagFileProp) + require(imageTagFile.isFile, s"No file found for image tag at ${imageTagFile.getAbsolutePath}.") + Files.toString(imageTagFile, Charsets.UTF_8).trim + } + + def getTestImageRepo: String = { + val imageRepo = System.getProperty("spark.kubernetes.test.imageRepo") + require(imageRepo != null, "Image repo must be provided in system properties.") + imageRepo + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala new file mode 100644 index 0000000000000..0807a68cd823c --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +package object constants { + val MINIKUBE_TEST_BACKEND = "minikube" + val GCE_TEST_BACKEND = "gce" +} diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index 022191d0070fd..91f64141e5318 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -39,7 +39,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")

      Cannot find driver {driverId}

      - return UIUtils.basicSparkPage(content, s"Details for Job $driverId") + return UIUtils.basicSparkPage(request, content, s"Details for Job $driverId") } val driverState = state.get val driverHeaders = Seq("Driver property", "Value") @@ -68,7 +68,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") retryHeaders, retryRow, Iterable.apply(driverState.description.retryState)) val content =

      Driver state information for driver id {driverId}

      - Back to Drivers + Back to Drivers

      Driver state: {driverState.state}

      @@ -87,7 +87,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
      ; - UIUtils.basicSparkPage(content, s"Details for Job $driverId") + UIUtils.basicSparkPage(request, content, s"Details for Job $driverId") } private def launchedRow(submissionState: Option[MesosClusterSubmissionState]): Seq[Node] = { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 88a6614d51384..c53285331ea68 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -62,7 +62,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( {retryTable} ; - UIUtils.basicSparkPage(content, "Spark Drivers for Mesos cluster") + UIUtils.basicSparkPage(request, content, "Spark Drivers for Mesos cluster") } private def queuedRow(submission: MesosDriverDescription): Seq[Node] = { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala index 604978967d6db..15bbe60d6c8fb 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -40,7 +40,7 @@ private[spark] class MesosClusterUI( override def initialize() { attachPage(new MesosClusterPage(this)) attachPage(new DriverPage(this)) - attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR) } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index d224a7325820a..7d80eedcc43ce 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -30,8 +30,7 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} -import org.apache.spark.deploy.mesos.MesosDriverDescription -import org.apache.spark.deploy.mesos.config +import org.apache.spark.deploy.mesos.{config, MesosDriverDescription} import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.Utils @@ -418,6 +417,18 @@ private[spark] class MesosClusterScheduler( envBuilder.build() } + private def isContainerLocalAppJar(desc: MesosDriverDescription): Boolean = { + val isLocalJar = desc.jarUrl.startsWith("local://") + val isContainerLocal = desc.conf.getOption("spark.mesos.appJar.local.resolution.mode").exists { + case "container" => true + case "host" => false + case other => + logWarning(s"Unknown spark.mesos.appJar.local.resolution.mode $other, using host.") + false + } + isLocalJar && isContainerLocal + } + private def getDriverUris(desc: MesosDriverDescription): List[CommandInfo.URI] = { val confUris = List(conf.getOption("spark.mesos.uris"), desc.conf.getOption("spark.mesos.uris"), @@ -425,10 +436,14 @@ private[spark] class MesosClusterScheduler( _.map(_.split(",").map(_.trim)) ).flatten - val jarUrl = desc.jarUrl.stripPrefix("file:").stripPrefix("local:") - - ((jarUrl :: confUris) ++ getDriverExecutorURI(desc).toList).map(uri => - CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) + if (isContainerLocalAppJar(desc)) { + (confUris ++ getDriverExecutorURI(desc).toList).map(uri => + CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) + } else { + val jarUrl = desc.jarUrl.stripPrefix("file:").stripPrefix("local:") + ((jarUrl :: confUris) ++ getDriverExecutorURI(desc).toList).map(uri => + CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) + } } private def getContainerInfo(desc: MesosDriverDescription): ContainerInfo.Builder = { @@ -480,7 +495,14 @@ private[spark] class MesosClusterScheduler( (cmdExecutable, ".") } val cmdOptions = generateCmdOption(desc, sandboxPath).mkString(" ") - val primaryResource = new File(sandboxPath, desc.jarUrl.split("/").last).toString() + val primaryResource = { + if (isContainerLocalAppJar(desc)) { + new File(desc.jarUrl.stripPrefix("local://")).toString() + } else { + new File(sandboxPath, desc.jarUrl.split("/").last).toString() + } + } + val appArguments = desc.command.arguments.mkString(" ") s"$executable $cmdOptions $primaryResource $appArguments" @@ -530,9 +552,9 @@ private[spark] class MesosClusterScheduler( .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) } .toMap (defaultConf ++ driverConf).foreach { case (key, value) => - options ++= Seq("--conf", s""""$key=${shellEscape(value)}"""".stripMargin) } + options ++= Seq("--conf", s"${key}=${value}") } - options + options.map(shellEscape) } /** diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 53f5f61cca486..1ce2f816dffb2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -227,7 +227,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "") + val extraJavaOpts = conf.getOption("spark.executor.extraJavaOptions").map { + Utils.substituteAppNExecIds(_, appId, taskId) + }.getOrElse("") // Set the environment variable through a command prefix // to append to the existing value of the variable @@ -632,7 +634,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( slave.hostname, externalShufflePort, sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", - s"${sc.conf.getTimeAsMs("spark.network.timeout", "120s")}ms"), + s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s")}s"), sc.conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) slave.shuffleRegistered = true } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index d6d939d246109..71a70ff048ccc 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -111,7 +111,9 @@ private[spark] class MesosFineGrainedSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("") + val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").map { + Utils.substituteAppNExecIds(_, appId, execId) + }.getOrElse("") val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p => Utils.libraryPathEnvPrefix(Seq(p)) diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index f4bd1ee9da6f7..b790c7cd27794 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -789,6 +789,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.nodeBlacklist).thenReturn(Set[String]()) when(taskScheduler.sc).thenReturn(sc) externalShuffleClient = mock[MesosExternalShuffleClient] diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index d04989e138f83..ecc576910db9e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} -import java.lang.reflect.InvocationTargetException +import java.lang.reflect.{InvocationTargetException, Modifier} import java.net.{Socket, URI, URL} import java.security.PrivilegedExceptionAction import java.util.concurrent.{TimeoutException, TimeUnit} @@ -308,7 +308,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logError("Uncaught exception: ", e) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, - "Uncaught exception: " + e) + "Uncaught exception: " + StringUtils.stringifyException(e)) } } @@ -346,7 +346,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends synchronized { if (!finished) { val inShutdown = ShutdownHookManager.inShutdown() - if (registered) { + if (registered || !isClusterMode) { exitCode = code finalStatus = status } else { @@ -389,37 +389,40 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } private def registerAM( + host: String, + port: Int, _sparkConf: SparkConf, - _rpcEnv: RpcEnv, - driverRef: RpcEndpointRef, - uiAddress: Option[String]) = { + uiAddress: Option[String]): Unit = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() val historyAddress = ApplicationMaster .getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId) - val driverUrl = RpcEndpointAddress( - _sparkConf.get("spark.driver.host"), - _sparkConf.get("spark.driver.port").toInt, + client.register(host, port, yarnConf, _sparkConf, uiAddress, historyAddress) + registered = true + } + + private def createAllocator(driverRef: RpcEndpointRef, _sparkConf: SparkConf): Unit = { + val appId = client.getAttemptId().getApplicationId().toString() + val driverUrl = RpcEndpointAddress(driverRef.address.host, driverRef.address.port, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString // Before we initialize the allocator, let's log the information about how executors will // be run up front, to avoid printing this out for every single executor being launched. // Use placeholders for information that changes such as executor IDs. logInfo { - val executorMemory = sparkConf.get(EXECUTOR_MEMORY).toInt - val executorCores = sparkConf.get(EXECUTOR_CORES) - val dummyRunner = new ExecutorRunnable(None, yarnConf, sparkConf, driverUrl, "", + val executorMemory = _sparkConf.get(EXECUTOR_MEMORY).toInt + val executorCores = _sparkConf.get(EXECUTOR_CORES) + val dummyRunner = new ExecutorRunnable(None, yarnConf, _sparkConf, driverUrl, "", "", executorMemory, executorCores, appId, securityMgr, localResources) dummyRunner.launchContextDebugInfo() } - allocator = client.register(driverUrl, - driverRef, + allocator = client.createAllocator( yarnConf, _sparkConf, - uiAddress, - historyAddress, + driverUrl, + driverRef, securityMgr, localResources) @@ -434,15 +437,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends reporterThread = launchReporterThread() } - /** - * @return An [[RpcEndpoint]] that communicates with the driver's scheduler backend. - */ - private def createSchedulerRef(host: String, port: String): RpcEndpointRef = { - rpcEnv.setupEndpointRef( - RpcAddress(host, port.toInt), - YarnSchedulerBackend.ENDPOINT_NAME) - } - private def runDriver(): Unit = { addAmIpFilter(None) userClassThread = startUserApplication() @@ -456,11 +450,16 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends Duration(totalWaitTime, TimeUnit.MILLISECONDS)) if (sc != null) { rpcEnv = sc.env.rpcEnv - val driverRef = createSchedulerRef( - sc.getConf.get("spark.driver.host"), - sc.getConf.get("spark.driver.port")) - registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl)) - registered = true + + val userConf = sc.getConf + val host = userConf.get("spark.driver.host") + val port = userConf.get("spark.driver.port").toInt + registerAM(host, port, userConf, sc.ui.map(_.webUrl)) + + val driverRef = rpcEnv.setupEndpointRef( + RpcAddress(host, port), + YarnSchedulerBackend.ENDPOINT_NAME) + createAllocator(driverRef, userConf) } else { // Sanity check; should never happen in normal operation, since sc should only be null // if the user app did not create a SparkContext. @@ -486,10 +485,18 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val amCores = sparkConf.get(AM_CORES) rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr, amCores, true) - val driverRef = waitForSparkDriver() + + // The client-mode AM doesn't listen for incoming connections, so report an invalid port. + registerAM(hostname, -1, sparkConf, sparkConf.getOption("spark.driver.appUIAddress")) + + // The driver should be up and listening, so unlike cluster mode, just try to connect to it + // with no waiting or retrying. + val (driverHost, driverPort) = Utils.parseHostPort(args.userArgs(0)) + val driverRef = rpcEnv.setupEndpointRef( + RpcAddress(driverHost, driverPort), + YarnSchedulerBackend.ENDPOINT_NAME) addAmIpFilter(Some(driverRef)) - registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress")) - registered = true + createAllocator(driverRef, sparkConf) // In client mode the actor will stop the reporter thread. reporterThread.join() @@ -508,6 +515,10 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, s"Max number of executor failures ($maxNumExecutorFailures) reached") + } else if (allocator.isAllNodeBlacklisted) { + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, + "Due to executor failures all available nodes are blacklisted") } else { logDebug("Sending progress") allocator.allocateResources() @@ -600,40 +611,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - private def waitForSparkDriver(): RpcEndpointRef = { - logInfo("Waiting for Spark driver to be reachable.") - var driverUp = false - val hostport = args.userArgs(0) - val (driverHost, driverPort) = Utils.parseHostPort(hostport) - - // Spark driver should already be up since it launched us, but we don't want to - // wait forever, so wait 100 seconds max to match the cluster mode setting. - val totalWaitTimeMs = sparkConf.get(AM_MAX_WAIT_TIME) - val deadline = System.currentTimeMillis + totalWaitTimeMs - - while (!driverUp && !finished && System.currentTimeMillis < deadline) { - try { - val socket = new Socket(driverHost, driverPort) - socket.close() - logInfo("Driver now available: %s:%s".format(driverHost, driverPort)) - driverUp = true - } catch { - case e: Exception => - logError("Failed to connect to driver at %s:%s, retrying ...". - format(driverHost, driverPort)) - Thread.sleep(100L) - } - } - - if (!driverUp) { - throw new SparkException("Failed to connect to driver!") - } - - sparkConf.set("spark.driver.host", driverHost) - sparkConf.set("spark.driver.port", driverPort.toString) - createSchedulerRef(driverHost, driverPort.toString) - } - /** Add the Yarn IP filter that is required for properly securing the UI. */ private def addAmIpFilter(driver: Option[RpcEndpointRef]) = { val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) @@ -675,9 +652,14 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val userThread = new Thread { override def run() { try { - mainMethod.invoke(null, userArgs.toArray) - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) - logDebug("Done running users class") + if (!Modifier.isStatic(mainMethod.getModifiers)) { + logError(s"Could not find static main method in object ${args.userClass}") + finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS) + } else { + mainMethod.invoke(null, userArgs.toArray) + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + logDebug("Done running user class") + } } catch { case e: InvocationTargetException => e.getCause match { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 5763c3dbc5a8a..793d012218490 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -892,12 +892,15 @@ private[spark] class Client( // Include driver-specific java options if we are launching a driver if (isClusterMode) { sparkConf.get(DRIVER_JAVA_OPTIONS).foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + javaOpts ++= Utils.splitCommandString(opts) + .map(Utils.substituteAppId(_, appId.toString)) + .map(YarnSparkHadoopUtil.escapeForShell) } val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { - prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths))) + prefixEnv = Some(createLibraryPathPrefix(libraryPaths.mkString(File.pathSeparator), + sparkConf)) } if (sparkConf.get(AM_JAVA_OPTIONS).isDefined) { logWarning(s"${AM_JAVA_OPTIONS.key} will not take effect in cluster mode") @@ -914,10 +917,12 @@ private[spark] class Client( s"(was '$opts'). Use spark.yarn.am.memory instead." throw new SparkException(msg) } - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + javaOpts ++= Utils.splitCommandString(opts) + .map(Utils.substituteAppId(_, appId.toString)) + .map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(AM_LIBRARY_PATH).foreach { paths => - prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) + prefixEnv = Some(createLibraryPathPrefix(paths, sparkConf)) } } @@ -1015,8 +1020,7 @@ private[spark] class Client( appId: ApplicationId, returnOnRunning: Boolean = false, logApplicationReport: Boolean = true, - interval: Long = sparkConf.get(REPORT_INTERVAL)): - (YarnApplicationState, FinalApplicationStatus) = { + interval: Long = sparkConf.get(REPORT_INTERVAL)): YarnAppReport = { var lastState: YarnApplicationState = null while (true) { Thread.sleep(interval) @@ -1027,11 +1031,13 @@ private[spark] class Client( case e: ApplicationNotFoundException => logError(s"Application $appId not found.") cleanupStagingDir(appId) - return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) + return YarnAppReport(YarnApplicationState.KILLED, FinalApplicationStatus.KILLED, None) case NonFatal(e) => - logError(s"Failed to contact YARN for application $appId.", e) + val msg = s"Failed to contact YARN for application $appId." + logError(msg, e) // Don't necessarily clean up staging dir because status is unknown - return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED) + return YarnAppReport(YarnApplicationState.FAILED, FinalApplicationStatus.FAILED, + Some(msg)) } val state = report.getYarnApplicationState @@ -1069,14 +1075,14 @@ private[spark] class Client( } if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { cleanupStagingDir(appId) - return (state, report.getFinalApplicationStatus) + return createAppReport(report) } if (returnOnRunning && state == YarnApplicationState.RUNNING) { - return (state, report.getFinalApplicationStatus) + return createAppReport(report) } lastState = state @@ -1125,16 +1131,17 @@ private[spark] class Client( throw new SparkException(s"Application $appId finished with status: $state") } } else { - val (yarnApplicationState, finalApplicationStatus) = monitorApplication(appId) - if (yarnApplicationState == YarnApplicationState.FAILED || - finalApplicationStatus == FinalApplicationStatus.FAILED) { + val YarnAppReport(appState, finalState, diags) = monitorApplication(appId) + if (appState == YarnApplicationState.FAILED || finalState == FinalApplicationStatus.FAILED) { + diags.foreach { err => + logError(s"Application diagnostics message: $err") + } throw new SparkException(s"Application $appId finished with failed status") } - if (yarnApplicationState == YarnApplicationState.KILLED || - finalApplicationStatus == FinalApplicationStatus.KILLED) { + if (appState == YarnApplicationState.KILLED || finalState == FinalApplicationStatus.KILLED) { throw new SparkException(s"Application $appId is killed") } - if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { + if (finalState == FinalApplicationStatus.UNDEFINED) { throw new SparkException(s"The final status of application $appId is undefined") } } @@ -1148,7 +1155,7 @@ private[spark] class Client( val pyArchivesFile = new File(pyLibPath, "pyspark.zip") require(pyArchivesFile.exists(), s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.") - val py4jFile = new File(pyLibPath, "py4j-0.10.6-src.zip") + val py4jFile = new File(pyLibPath, "py4j-0.10.7-src.zip") require(py4jFile.exists(), s"$py4jFile not found; cannot run pyspark application in YARN mode.") Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) @@ -1473,6 +1480,29 @@ private object Client extends Logging { uri.startsWith(s"$LOCAL_SCHEME:") } + def createAppReport(report: ApplicationReport): YarnAppReport = { + val diags = report.getDiagnostics() + val diagsOpt = if (diags != null && diags.nonEmpty) Some(diags) else None + YarnAppReport(report.getYarnApplicationState(), report.getFinalApplicationStatus(), diagsOpt) + } + + /** + * Create a properly quoted and escaped library path string to be added as a prefix to the command + * executed by YARN. This is different from normal quoting / escaping due to YARN executing the + * command through "bash -c". + */ + def createLibraryPathPrefix(libpath: String, conf: SparkConf): String = { + val cmdPrefix = if (Utils.isWindows) { + Utils.libraryPathEnvPrefix(Seq(libpath)) + } else { + val envName = Utils.libraryPathEnvName + // For quotes, escape both the quote and the escape character when encoding in the command + // string. + val quoted = libpath.replace("\"", "\\\\\\\"") + envName + "=\\\"" + quoted + File.pathSeparator + "$" + envName + "\\\"" + } + getClusterPath(conf, cmdPrefix) + } } private[spark] class YarnClusterApplication extends SparkApplication { @@ -1487,3 +1517,8 @@ private[spark] class YarnClusterApplication extends SparkApplication { } } + +private[spark] case class YarnAppReport( + appState: YarnApplicationState, + finalState: FinalApplicationStatus, + diagnostics: Option[String]) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index ab08698035c98..49a0b93aa5c40 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -131,20 +131,20 @@ private[yarn] class ExecutorRunnable( // Extra options for the JVM val javaOpts = ListBuffer[String]() - // Set the environment variable through a command prefix - // to append to the existing value of the variable - var prefixEnv: Option[String] = None - // Set the JVM memory val executorMemoryString = executorMemory + "m" javaOpts += "-Xmx" + executorMemoryString // Set extra Java options for the executor, if defined sparkConf.get(EXECUTOR_JAVA_OPTIONS).foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + val subsOpt = Utils.substituteAppNExecIds(opts, appId, executorId) + javaOpts ++= Utils.splitCommandString(subsOpt).map(YarnSparkHadoopUtil.escapeForShell) } - sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p => - prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) + + // Set the library path through a command prefix to append to the existing value of the + // env variable. + val prefixEnv = sparkConf.get(EXECUTOR_LIBRARY_PATH).map { libPath => + Client.createLibraryPathPrefix(libPath, sparkConf) } javaOpts += "-Djava.io.tmpdir=" + diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index ebee3d431744d..fae054e0eea00 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -24,7 +24,7 @@ import java.util.regex.Pattern import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.control.NonFatal import org.apache.hadoop.yarn.api.records._ @@ -66,7 +66,8 @@ private[yarn] class YarnAllocator( appAttemptId: ApplicationAttemptId, securityMgr: SecurityManager, localResources: Map[String, LocalResource], - resolver: SparkRackResolver) + resolver: SparkRackResolver, + clock: Clock = new SystemClock) extends Logging { import YarnAllocator._ @@ -102,18 +103,14 @@ private[yarn] class YarnAllocator( private var executorIdCounter: Int = driverRef.askSync[Int](RetrieveLastAllocatedExecutorId) - // Queue to store the timestamp of failed executors - private val failedExecutorsTimeStamps = new Queue[Long]() + private[spark] val failureTracker = new FailureTracker(sparkConf, clock) - private var clock: Clock = new SystemClock - - private val executorFailuresValidityInterval = - sparkConf.get(EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L) + private val allocatorBlacklistTracker = + new YarnAllocatorBlacklistTracker(sparkConf, amClient, failureTracker) @volatile private var targetNumExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(sparkConf) - private var currentNodeBlacklist = Set.empty[String] // Executor loss reason requests that are pending - maps from executor ID for inquiry to a // list of requesters that should be responded to once we find out why the given executor @@ -149,7 +146,6 @@ private[yarn] class YarnAllocator( private val labelExpression = sparkConf.get(EXECUTOR_NODE_LABEL_EXPRESSION) - // A map to store preferred hostname and possible task numbers running on it. private var hostToLocalTaskCounts: Map[String, Int] = Map.empty @@ -160,26 +156,11 @@ private[yarn] class YarnAllocator( private[yarn] val containerPlacementStrategy = new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource, resolver) - /** - * Use a different clock for YarnAllocator. This is mainly used for testing. - */ - def setClock(newClock: Clock): Unit = { - clock = newClock - } - def getNumExecutorsRunning: Int = runningExecutors.size() - def getNumExecutorsFailed: Int = synchronized { - val endTime = clock.getTimeMillis() + def getNumExecutorsFailed: Int = failureTracker.numFailedExecutors - while (executorFailuresValidityInterval > 0 - && failedExecutorsTimeStamps.nonEmpty - && failedExecutorsTimeStamps.head < endTime - executorFailuresValidityInterval) { - failedExecutorsTimeStamps.dequeue() - } - - failedExecutorsTimeStamps.size - } + def isAllNodeBlacklisted: Boolean = allocatorBlacklistTracker.isAllNodeBlacklisted /** * A sequence of pending container requests that have not yet been fulfilled. @@ -204,9 +185,8 @@ private[yarn] class YarnAllocator( * @param localityAwareTasks number of locality aware tasks to be used as container placement hint * @param hostToLocalTaskCount a map of preferred hostname to possible task counts to be used as * container placement hint. - * @param nodeBlacklist a set of blacklisted nodes, which is passed in to avoid allocating new - * containers on them. It will be used to update the application master's - * blacklist. + * @param nodeBlacklist blacklisted nodes, which is passed in to avoid allocating new containers + * on them. It will be used to update the application master's blacklist. * @return Whether the new requested total is different than the old value. */ def requestTotalExecutorsWithPreferredLocalities( @@ -220,19 +200,7 @@ private[yarn] class YarnAllocator( if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal - - // Update blacklist infomation to YARN ResouceManager for this application, - // in order to avoid allocating new Containers on the problematic nodes. - val blacklistAdditions = nodeBlacklist -- currentNodeBlacklist - val blacklistRemovals = currentNodeBlacklist -- nodeBlacklist - if (blacklistAdditions.nonEmpty) { - logInfo(s"adding nodes to YARN application master's blacklist: $blacklistAdditions") - } - if (blacklistRemovals.nonEmpty) { - logInfo(s"removing nodes from YARN application master's blacklist: $blacklistRemovals") - } - amClient.updateBlacklist(blacklistAdditions.toList.asJava, blacklistRemovals.toList.asJava) - currentNodeBlacklist = nodeBlacklist + allocatorBlacklistTracker.setSchedulerBlacklistedNodes(nodeBlacklist) true } else { false @@ -268,6 +236,7 @@ private[yarn] class YarnAllocator( val allocateResponse = amClient.allocate(progressIndicator) val allocatedContainers = allocateResponse.getAllocatedContainers() + allocatorBlacklistTracker.setNumClusterNodes(allocateResponse.getNumClusterNodes) if (allocatedContainers.size > 0) { logDebug(("Allocated containers: %d. Current executor count: %d. " + @@ -602,8 +571,9 @@ private[yarn] class YarnAllocator( completedContainer.getDiagnostics, PMEM_EXCEEDED_PATTERN)) case _ => - // Enqueue the timestamp of failed executor - failedExecutorsTimeStamps.enqueue(clock.getTimeMillis()) + // all the failures which not covered above, like: + // disk failure, kill by app master or resource manager, ... + allocatorBlacklistTracker.handleResourceAllocationFailure(hostOpt) (true, "Container marked as failed: " + containerId + onHostStr + ". Exit status: " + completedContainer.getExitStatus + ". Diagnostics: " + completedContainer.getDiagnostics) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala new file mode 100644 index 0000000000000..1b48a0ee7ad32 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.HashMap + +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.scheduler.BlacklistTracker +import org.apache.spark.util.{Clock, SystemClock, Utils} + +/** + * YarnAllocatorBlacklistTracker is responsible for tracking the blacklisted nodes + * and synchronizing the node list to YARN. + * + * Blacklisted nodes are coming from two different sources: + * + *
        + *
      • from the scheduler as task level blacklisted nodes + *
      • from this class (tracked here) as YARN resource allocation problems + *
      + * + * The reason to realize this logic here (and not in the driver) is to avoid possible delays + * between synchronizing the blacklisted nodes with YARN and resource allocations. + */ +private[spark] class YarnAllocatorBlacklistTracker( + sparkConf: SparkConf, + amClient: AMRMClient[ContainerRequest], + failureTracker: FailureTracker) + extends Logging { + + private val blacklistTimeoutMillis = BlacklistTracker.getBlacklistTimeout(sparkConf) + + private val launchBlacklistEnabled = sparkConf.get(YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED) + + private val maxFailuresPerHost = sparkConf.get(MAX_FAILED_EXEC_PER_NODE) + + private val allocatorBlacklist = new HashMap[String, Long]() + + private var currentBlacklistedYarnNodes = Set.empty[String] + + private var schedulerBlacklist = Set.empty[String] + + private var numClusterNodes = Int.MaxValue + + def setNumClusterNodes(numClusterNodes: Int): Unit = { + this.numClusterNodes = numClusterNodes + } + + def handleResourceAllocationFailure(hostOpt: Option[String]): Unit = { + hostOpt match { + case Some(hostname) if launchBlacklistEnabled => + // failures on an already blacklisted nodes are not even tracked. + // otherwise, such failures could shutdown the application + // as resource requests are asynchronous + // and a late failure response could exceed MAX_EXECUTOR_FAILURES + if (!schedulerBlacklist.contains(hostname) && + !allocatorBlacklist.contains(hostname)) { + failureTracker.registerFailureOnHost(hostname) + updateAllocationBlacklistedNodes(hostname) + } + case _ => + failureTracker.registerExecutorFailure() + } + } + + private def updateAllocationBlacklistedNodes(hostname: String): Unit = { + val failuresOnHost = failureTracker.numFailuresOnHost(hostname) + if (failuresOnHost > maxFailuresPerHost) { + logInfo(s"blacklisting $hostname as YARN allocation failed $failuresOnHost times") + allocatorBlacklist.put( + hostname, + failureTracker.clock.getTimeMillis() + blacklistTimeoutMillis) + refreshBlacklistedNodes() + } + } + + def setSchedulerBlacklistedNodes(schedulerBlacklistedNodesWithExpiry: Set[String]): Unit = { + this.schedulerBlacklist = schedulerBlacklistedNodesWithExpiry + refreshBlacklistedNodes() + } + + def isAllNodeBlacklisted: Boolean = currentBlacklistedYarnNodes.size >= numClusterNodes + + private def refreshBlacklistedNodes(): Unit = { + removeExpiredYarnBlacklistedNodes() + val allBlacklistedNodes = schedulerBlacklist ++ allocatorBlacklist.keySet + synchronizeBlacklistedNodeWithYarn(allBlacklistedNodes) + } + + private def synchronizeBlacklistedNodeWithYarn(nodesToBlacklist: Set[String]): Unit = { + // Update blacklist information to YARN ResourceManager for this application, + // in order to avoid allocating new Containers on the problematic nodes. + val additions = (nodesToBlacklist -- currentBlacklistedYarnNodes).toList.sorted + val removals = (currentBlacklistedYarnNodes -- nodesToBlacklist).toList.sorted + if (additions.nonEmpty) { + logInfo(s"adding nodes to YARN application master's blacklist: $additions") + } + if (removals.nonEmpty) { + logInfo(s"removing nodes from YARN application master's blacklist: $removals") + } + amClient.updateBlacklist(additions.asJava, removals.asJava) + currentBlacklistedYarnNodes = nodesToBlacklist + } + + private def removeExpiredYarnBlacklistedNodes(): Unit = { + val now = failureTracker.clock.getTimeMillis() + allocatorBlacklist.retain { (_, expiryTime) => expiryTime > now } + } +} + +/** + * FailureTracker is responsible for tracking executor failures both for each host separately + * and for all hosts altogether. + */ +private[spark] class FailureTracker( + sparkConf: SparkConf, + val clock: Clock = new SystemClock) extends Logging { + + private val executorFailuresValidityInterval = + sparkConf.get(config.EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L) + + // Queue to store the timestamp of failed executors for each host + private val failedExecutorsTimeStampsPerHost = mutable.Map[String, mutable.Queue[Long]]() + + private val failedExecutorsTimeStamps = new mutable.Queue[Long]() + + private def updateAndCountFailures(failedExecutorsWithTimeStamps: mutable.Queue[Long]): Int = { + val endTime = clock.getTimeMillis() + while (executorFailuresValidityInterval > 0 && + failedExecutorsWithTimeStamps.nonEmpty && + failedExecutorsWithTimeStamps.head < endTime - executorFailuresValidityInterval) { + failedExecutorsWithTimeStamps.dequeue() + } + failedExecutorsWithTimeStamps.size + } + + def numFailedExecutors: Int = synchronized { + updateAndCountFailures(failedExecutorsTimeStamps) + } + + def registerFailureOnHost(hostname: String): Unit = synchronized { + val timeMillis = clock.getTimeMillis() + failedExecutorsTimeStamps.enqueue(timeMillis) + val failedExecutorsOnHost = + failedExecutorsTimeStampsPerHost.getOrElse(hostname, { + val failureOnHost = mutable.Queue[Long]() + failedExecutorsTimeStampsPerHost.put(hostname, failureOnHost) + failureOnHost + }) + failedExecutorsOnHost.enqueue(timeMillis) + } + + def registerExecutorFailure(): Unit = synchronized { + val timeMillis = clock.getTimeMillis() + failedExecutorsTimeStamps.enqueue(timeMillis) + } + + def numFailuresOnHost(hostname: String): Int = { + failedExecutorsTimeStampsPerHost.get(hostname).map { failedExecutorsOnHost => + updateAndCountFailures(failedExecutorsOnHost) + }.getOrElse(0) + } + +} + diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 17234b120ae13..b59dcf158d87c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -42,23 +42,20 @@ private[spark] class YarnRMClient extends Logging { /** * Registers the application master with the RM. * + * @param driverHost Host name where driver is running. + * @param driverPort Port where driver is listening. * @param conf The Yarn configuration. * @param sparkConf The Spark configuration. * @param uiAddress Address of the SparkUI. * @param uiHistoryAddress Address of the application on the History Server. - * @param securityMgr The security manager. - * @param localResources Map with information about files distributed via YARN's cache. */ def register( - driverUrl: String, - driverRef: RpcEndpointRef, + driverHost: String, + driverPort: Int, conf: YarnConfiguration, sparkConf: SparkConf, uiAddress: Option[String], - uiHistoryAddress: String, - securityMgr: SecurityManager, - localResources: Map[String, LocalResource] - ): YarnAllocator = { + uiHistoryAddress: String): Unit = { amClient = AMRMClient.createAMRMClient() amClient.init(conf) amClient.start() @@ -70,10 +67,19 @@ private[spark] class YarnRMClient extends Logging { logInfo("Registering the ApplicationMaster") synchronized { - amClient.registerApplicationMaster(driverRef.address.host, driverRef.address.port, - trackingUrl) + amClient.registerApplicationMaster(driverHost, driverPort, trackingUrl) registered = true } + } + + def createAllocator( + conf: YarnConfiguration, + sparkConf: SparkConf, + driverUrl: String, + driverRef: RpcEndpointRef, + securityMgr: SecurityManager, + localResources: Map[String, LocalResource]): YarnAllocator = { + require(registered, "Must register AM before creating allocator.") new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, localResources, new SparkRackResolver()) } @@ -88,6 +94,9 @@ private[spark] class YarnRMClient extends Logging { if (registered) { amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) } + if (amClient != null) { + amClient.stop() + } } /** Returns the attempt ID. */ diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 8eda6cb1277c5..7250e58b6c49a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -200,7 +200,29 @@ object YarnSparkHadoopUtil { .map(new Path(_).getFileSystem(hadoopConf)) .getOrElse(FileSystem.get(hadoopConf)) - filesystemsToAccess + stagingFS + // Add the list of available namenodes for all namespaces in HDFS federation. + // If ViewFS is enabled, this is skipped as ViewFS already handles delegation tokens for its + // namespaces. + val hadoopFilesystems = if (stagingFS.getScheme == "viewfs") { + Set.empty + } else { + val nameservices = hadoopConf.getTrimmedStrings("dfs.nameservices") + // Retrieving the filesystem for the nameservices where HA is not enabled + val filesystemsWithoutHA = nameservices.flatMap { ns => + Option(hadoopConf.get(s"dfs.namenode.rpc-address.$ns")).map { nameNode => + new Path(s"hdfs://$nameNode").getFileSystem(hadoopConf) + } + } + // Retrieving the filesystem for the nameservices where HA is enabled + val filesystemsWithHA = nameservices.flatMap { ns => + Option(hadoopConf.get(s"dfs.ha.namenodes.$ns")).map { _ => + new Path(s"hdfs://$ns").getFileSystem(hadoopConf) + } + } + (filesystemsWithoutHA ++ filesystemsWithHA).toSet + } + + filesystemsToAccess ++ hadoopFilesystems + stagingFS } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 1a99b3bd57672..129084a86597a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -328,4 +328,10 @@ package object config { CACHED_FILES_TYPES, CACHED_CONF_ARCHIVE) + /* YARN allocator-level blacklisting related config entries. */ + private[spark] val YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED = + ConfigBuilder("spark.yarn.blacklist.executor.launch.blacklisting.enabled") + .booleanConf + .createWithDefault(false) + } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala index d4eeb6bbcf886..26a2e5d730218 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala @@ -44,6 +44,10 @@ private[yarn] class YARNHadoopDelegationTokenManager( // public for testing val credentialProviders = getCredentialProviders + if (credentialProviders.nonEmpty) { + logDebug("Using the following YARN-specific credential providers: " + + s"${credentialProviders.keys.mkString(", ")}.") + } /** * Writes delegation tokens to creds. Delegation tokens are fetched from all registered diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 06e54a2eaf95a..f1a8df00f9c5b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.YarnApplicationState import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.yarn.{Client, ClientArguments} +import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnAppReport} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkAppHandle @@ -75,13 +75,23 @@ private[spark] class YarnClientSchedulerBackend( val monitorInterval = conf.get(CLIENT_LAUNCH_MONITOR_INTERVAL) assert(client != null && appId.isDefined, "Application has not been submitted yet!") - val (state, _) = client.monitorApplication(appId.get, returnOnRunning = true, - interval = monitorInterval) // blocking + val YarnAppReport(state, _, diags) = client.monitorApplication(appId.get, + returnOnRunning = true, interval = monitorInterval) if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { - throw new SparkException("Yarn application has already ended! " + - "It might have been killed or unable to launch application master.") + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + val genericMessage = "The YARN application has already ended! " + + "It might have been killed or the Application Master may have failed to start. " + + "Check the YARN application logs for more details." + val exceptionMsg = diags match { + case Some(msg) => + logError(genericMessage) + msg + + case None => + genericMessage + } + throw new SparkException(exceptionMsg) } if (state == YarnApplicationState.RUNNING) { logInfo(s"Application ${appId.get} has started running.") @@ -100,8 +110,13 @@ private[spark] class YarnClientSchedulerBackend( override def run() { try { - val (state, _) = client.monitorApplication(appId.get, logApplicationReport = false) - logError(s"Yarn application has already exited with state $state!") + val YarnAppReport(_, state, diags) = + client.monitorApplication(appId.get, logApplicationReport = true) + logError(s"YARN application has exited unexpectedly with state $state! " + + "Check the YARN application logs for more details.") + diags.foreach { err => + logError(s"Diagnostics message: $err") + } allowInterrupt = false sc.stop() } catch { @@ -124,7 +139,7 @@ private[spark] class YarnClientSchedulerBackend( private def asyncMonitorApplication(): MonitorThread = { assert(client != null && appId.isDefined, "Application has not been submitted yet!") val t = new MonitorThread - t.setName("Yarn application state monitor") + t.setName("YARN application state monitor") t.setDaemon(true) t } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index ac67f2196e0a0..b0abcc9149d08 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -36,6 +36,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.launcher._ import org.apache.spark.util.Utils @@ -216,6 +217,14 @@ abstract class BaseYarnClusterSuite props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") + // SPARK-24446: make sure special characters in the library path do not break containers. + if (!Utils.isWindows) { + val libPath = """/tmp/does not exist:$PWD/tmp:/tmp/quote":/tmp/ampersand&""" + props.setProperty(AM_LIBRARY_PATH.key, libPath) + props.setProperty(DRIVER_LIBRARY_PATH.key, libPath) + props.setProperty(EXECUTOR_LIBRARY_PATH.key, libPath) + } + yarnCluster.getConfig().asScala.foreach { e => props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala new file mode 100644 index 0000000000000..4f77b9c99dd25 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.util.ManualClock + +class FailureTrackerSuite extends SparkFunSuite with Matchers { + + override def beforeAll(): Unit = { + super.beforeAll() + } + + test("failures expire if validity interval is set") { + val sparkConf = new SparkConf() + sparkConf.set(config.EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS, 100L) + + val clock = new ManualClock() + val failureTracker = new FailureTracker(sparkConf, clock) + + clock.setTime(0) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (1) + + clock.setTime(10) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (1) + failureTracker.numFailedExecutors should be (2) + + clock.setTime(20) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(30) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + + clock.setTime(101) + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(231) + failureTracker.numFailuresOnHost("host1") should be (0) + failureTracker.numFailuresOnHost("host2") should be (0) + failureTracker.numFailedExecutors should be (0) + } + + + test("failures never expire if validity interval is not set (-1)") { + val sparkConf = new SparkConf() + + val clock = new ManualClock() + val failureTracker = new FailureTracker(sparkConf, clock) + + clock.setTime(0) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (1) + + clock.setTime(10) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (1) + failureTracker.numFailedExecutors should be (2) + + clock.setTime(20) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(30) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + + clock.setTime(1000) + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + } + +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala new file mode 100644 index 0000000000000..aeac68e6ed330 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import java.util.Arrays +import java.util.Collections + +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.yarn.config.YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED +import org.apache.spark.internal.config.{BLACKLIST_TIMEOUT_CONF, MAX_FAILED_EXEC_PER_NODE} +import org.apache.spark.util.ManualClock + +class YarnAllocatorBlacklistTrackerSuite extends SparkFunSuite with Matchers + with BeforeAndAfterEach { + + val BLACKLIST_TIMEOUT = 100L + val MAX_FAILED_EXEC_PER_NODE_VALUE = 2 + + var amClientMock: AMRMClient[ContainerRequest] = _ + var yarnBlacklistTracker: YarnAllocatorBlacklistTracker = _ + var failureTracker: FailureTracker = _ + var clock: ManualClock = _ + + override def beforeEach(): Unit = { + val sparkConf = new SparkConf() + sparkConf.set(BLACKLIST_TIMEOUT_CONF, BLACKLIST_TIMEOUT) + sparkConf.set(YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED, true) + sparkConf.set(MAX_FAILED_EXEC_PER_NODE, MAX_FAILED_EXEC_PER_NODE_VALUE) + clock = new ManualClock() + + amClientMock = mock(classOf[AMRMClient[ContainerRequest]]) + failureTracker = new FailureTracker(sparkConf, clock) + yarnBlacklistTracker = + new YarnAllocatorBlacklistTracker(sparkConf, amClientMock, failureTracker) + yarnBlacklistTracker.setNumClusterNodes(4) + super.beforeEach() + } + + test("expiring its own blacklisted nodes") { + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host")) + // host should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host"), Collections.emptyList()) + } + } + + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host")) + // the third failure on the host triggers the blacklisting + verify(amClientMock).updateBlacklist(Arrays.asList("host"), Collections.emptyList()) + + clock.advance(BLACKLIST_TIMEOUT) + + // trigger synchronisation of blacklisted nodes with YARN + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set()) + verify(amClientMock).updateBlacklist(Collections.emptyList(), Arrays.asList("host")) + } + + test("not handling the expiry of scheduler blacklisted nodes") { + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1", "host2"), Collections.emptyList()) + + // advance timer more then host1, host2 expiry time + clock.advance(200L) + + // expired blacklisted nodes (simulating a resource request) + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2")) + // no change is communicated to YARN regarding the blacklisting + verify(amClientMock).updateBlacklist(Collections.emptyList(), Collections.emptyList()) + } + + test("combining scheduler and allocation blacklist") { + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host1")) + // host1 should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host1"), Collections.emptyList()) + } + } + + // as this is the third failure on host1 the node will be blacklisted + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host1")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1"), Collections.emptyList()) + + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host2", "host3")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host2", "host3"), Collections.emptyList()) + + clock.advance(10L) + + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host3", "host4")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host4"), Arrays.asList("host2")) + } + + test("blacklist all available nodes") { + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2", "host3")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1", "host2", "host3"), Collections.emptyList()) + + clock.advance(60L) + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host4")) + // host4 should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host4"), Collections.emptyList()) + } + } + + // the third failure on the host triggers the blacklisting + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host4")) + + verify(amClientMock).updateBlacklist(Arrays.asList("host4"), Collections.emptyList()) + assert(yarnBlacklistTracker.isAllNodeBlacklisted === true) + } +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 525abb6f2b350..3f783baed110d 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -59,6 +59,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter var rmClient: AMRMClient[ContainerRequest] = _ + var clock: ManualClock = _ + var containerNum = 0 override def beforeEach() { @@ -66,6 +68,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter rmClient = AMRMClient.createAMRMClient() rmClient.init(conf) rmClient.start() + clock = new ManualClock() } override def afterEach() { @@ -101,7 +104,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter appAttemptId, new SecurityManager(sparkConf), Map(), - new MockResolver()) + new MockResolver(), + clock) } def createContainer(host: String): Container = { @@ -332,10 +336,14 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map(), Set("hostA")) verify(mockAmClient).updateBlacklist(Seq("hostA").asJava, Seq[String]().asJava) - handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map(), Set("hostA", "hostB")) + val blacklistedNodes = Set( + "hostA", + "hostB" + ) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map(), blacklistedNodes) verify(mockAmClient).updateBlacklist(Seq("hostB").asJava, Seq[String]().asJava) - handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map(), Set()) + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map(), Set.empty) verify(mockAmClient).updateBlacklist(Seq[String]().asJava, Seq("hostA", "hostB").asJava) } @@ -353,8 +361,6 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter test("window based failure executor counting") { sparkConf.set("spark.yarn.executor.failuresValidityInterval", "100s") val handler = createAllocator(4) - val clock = new ManualClock(0L) - handler.setClock(clock) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index a129be7c06b53..3b78b88de778d 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -265,22 +265,17 @@ class YarnClusterSuite extends BaseYarnClusterSuite { // needed locations. val sparkHome = sys.props("spark.test.home") val pythonPath = Seq( - s"$sparkHome/python/lib/py4j-0.10.6-src.zip", + s"$sparkHome/python/lib/py4j-0.10.7-src.zip", s"$sparkHome/python") val extraEnvVars = Map( "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) ++ extraEnv - val moduleDir = - if (clientMode) { - // In client-mode, .py files added with --py-files are not visible in the driver. - // This is something that the launcher library would have to handle. - tempDir - } else { - val subdir = new File(tempDir, "pyModules") - subdir.mkdir() - subdir - } + val moduleDir = { + val subdir = new File(tempDir, "pyModules") + subdir.mkdir() + subdir + } val pyModule = new File(moduleDir, "mod1.py") Files.write(TEST_PYMODULE, pyModule, StandardCharsets.UTF_8) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index f21353aa007c8..61c0c43f7c04f 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -21,7 +21,8 @@ import java.io.{File, IOException} import java.nio.charset.StandardCharsets import com.google.common.io.{ByteStreams, Files} -import org.apache.hadoop.io.Text +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers @@ -141,4 +142,66 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging } + test("SPARK-24149: retrieve all namenodes from HDFS") { + val sparkConf = new SparkConf() + val basicFederationConf = new Configuration() + basicFederationConf.set("fs.defaultFS", "hdfs://localhost:8020") + basicFederationConf.set("dfs.nameservices", "ns1,ns2") + basicFederationConf.set("dfs.namenode.rpc-address.ns1", "localhost:8020") + basicFederationConf.set("dfs.namenode.rpc-address.ns2", "localhost:8021") + val basicFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(basicFederationConf), + new Path("hdfs://localhost:8021").getFileSystem(basicFederationConf)) + val basicFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, basicFederationConf) + basicFederationResult should be (basicFederationExpected) + + // when viewfs is enabled, namespaces are handled by it, so we don't need to take care of them + val viewFsConf = new Configuration() + viewFsConf.addResource(basicFederationConf) + viewFsConf.set("fs.defaultFS", "viewfs://clusterX/") + viewFsConf.set("fs.viewfs.mounttable.clusterX.link./home", "hdfs://localhost:8020/") + val viewFsExpected = Set(new Path("viewfs://clusterX/").getFileSystem(viewFsConf)) + YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, viewFsConf) should be (viewFsExpected) + + // invalid config should not throw NullPointerException + val invalidFederationConf = new Configuration() + invalidFederationConf.addResource(basicFederationConf) + invalidFederationConf.unset("dfs.namenode.rpc-address.ns2") + val invalidFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(invalidFederationConf)) + val invalidFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, invalidFederationConf) + invalidFederationResult should be (invalidFederationExpected) + + // no namespaces defined, ie. old case + val noFederationConf = new Configuration() + noFederationConf.set("fs.defaultFS", "hdfs://localhost:8020") + val noFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(noFederationConf)) + val noFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, noFederationConf) + noFederationResult should be (noFederationExpected) + + // federation and HA enabled + val federationAndHAConf = new Configuration() + federationAndHAConf.set("fs.defaultFS", "hdfs://clusterXHA") + federationAndHAConf.set("dfs.nameservices", "clusterXHA,clusterYHA") + federationAndHAConf.set("dfs.ha.namenodes.clusterXHA", "x-nn1,x-nn2") + federationAndHAConf.set("dfs.ha.namenodes.clusterYHA", "y-nn1,y-nn2") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn1", "localhost:8020") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn2", "localhost:8021") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn1", "localhost:8022") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn2", "localhost:8023") + federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterXHA", + "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider") + federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterYHA", + "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider") + + val federationAndHAExpected = Set( + new Path("hdfs://clusterXHA").getFileSystem(federationAndHAConf), + new Path("hdfs://clusterYHA").getFileSystem(federationAndHAConf)) + val federationAndHAResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, federationAndHAConf) + federationAndHAResult should be (federationAndHAExpected) + } } diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index bac154e10ae62..bf3da18c3706e 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -28,6 +28,6 @@ export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" - export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:${PYTHONPATH}" + export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:${PYTHONPATH}" export PYSPARK_PYTHONPATH_SET=1 fi diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 5fa75fe348e68..dc95751bf905c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -398,7 +398,7 @@ hintStatement ; fromClause - : FROM relation (',' relation)* lateralView* + : FROM relation (',' relation)* lateralView* pivotClause? ; aggregation @@ -413,6 +413,10 @@ groupingSet | expression ; +pivotClause + : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn=identifier IN '(' pivotValues+=constant (',' pivotValues+=constant)* ')' ')' + ; + lateralView : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? ; @@ -535,18 +539,11 @@ expression booleanExpression : NOT booleanExpression #logicalNot | EXISTS '(' query ')' #exists - | predicated #booleanDefault + | valueExpression predicate? #predicated | left=booleanExpression operator=AND right=booleanExpression #logicalBinary | left=booleanExpression operator=OR right=booleanExpression #logicalBinary ; -// workaround for: -// https://github.com/antlr/antlr4/issues/780 -// https://github.com/antlr/antlr4/issues/781 -predicated - : valueExpression predicate? - ; - predicate : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression | NOT? kind=IN '(' expression (',' expression)* ')' @@ -588,6 +585,7 @@ primaryExpression | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference | '(' expression ')' #parenthesizedExpression + | EXTRACT '(' field=identifier FROM source=valueExpression ')' #extract ; constant @@ -725,7 +723,7 @@ nonReserved | ADD | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST | AFTER | MAP | ARRAY | STRUCT - | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER + | PIVOT | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS | GROUPING | CUBE | ROLLUP @@ -735,6 +733,7 @@ nonReserved | VIEW | REPLACE | IF | POSITION + | EXTRACT | NO | DATA | START | TRANSACTION | COMMIT | ROLLBACK | IGNORE | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION @@ -745,7 +744,7 @@ nonReserved | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | RECOVER | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION | LOCAL | INPATH | ASC | DESC | LIMIT | RENAME | SETS - | AT | NULLS | OVERWRITE | ALL | ALTER | AS | BETWEEN | BY | CREATE | DELETE + | AT | NULLS | OVERWRITE | ALL | ANY | ALTER | AS | BETWEEN | BY | CREATE | DELETE | DESCRIBE | DROP | EXISTS | FALSE | FOR | GROUP | IN | INSERT | INTO | IS |LIKE | NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE | AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN @@ -760,6 +759,7 @@ FROM: 'FROM'; ADD: 'ADD'; AS: 'AS'; ALL: 'ALL'; +ANY: 'ANY'; DISTINCT: 'DISTINCT'; WHERE: 'WHERE'; GROUP: 'GROUP'; @@ -805,6 +805,7 @@ RIGHT: 'RIGHT'; FULL: 'FULL'; NATURAL: 'NATURAL'; ON: 'ON'; +PIVOT: 'PIVOT'; LATERAL: 'LATERAL'; WINDOW: 'WINDOW'; OVER: 'OVER'; @@ -872,6 +873,7 @@ TRAILING: 'TRAILING'; IF: 'IF'; POSITION: 'POSITION'; +EXTRACT: 'EXTRACT'; EQ : '=' | '=='; NSEQ: '<=>'; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java new file mode 100644 index 0000000000000..05879902a4ed9 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +/** + * Contains all the Utils methods used in the masking expressions. + */ +public class MaskExpressionsUtils { + static final int UNMASKED_VAL = -1; + + /** + * Returns the masking character for {@param c} or {@param c} is it should not be masked. + * @param c the character to transform + * @param maskedUpperChar the character to use instead of a uppercase letter + * @param maskedLowerChar the character to use instead of a lowercase letter + * @param maskedDigitChar the character to use instead of a digit + * @param maskedOtherChar the character to use instead of a any other character + * @return masking character for {@param c} + */ + public static int transformChar( + final int c, + int maskedUpperChar, + int maskedLowerChar, + int maskedDigitChar, + int maskedOtherChar) { + switch(Character.getType(c)) { + case Character.UPPERCASE_LETTER: + if(maskedUpperChar != UNMASKED_VAL) { + return maskedUpperChar; + } + break; + + case Character.LOWERCASE_LETTER: + if(maskedLowerChar != UNMASKED_VAL) { + return maskedLowerChar; + } + break; + + case Character.DECIMAL_DIGIT_NUMBER: + if(maskedDigitChar != UNMASKED_VAL) { + return maskedDigitChar; + } + break; + + default: + if(maskedOtherChar != UNMASKED_VAL) { + return maskedOtherChar; + } + break; + } + + return c; + } + + /** + * Returns the replacement char to use according to the {@param rep} specified by the user and + * the {@param def} default. + */ + public static int getReplacementChar(String rep, int def) { + if (rep != null && rep.length() > 0) { + return rep.codePointAt(0); + } + return def; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index d5d934bc91cab..cf2a5ed2e27f9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -83,7 +83,7 @@ public static long calculateSizeOfUnderlyingByteArray(long numFields, int elemen private long elementOffset; private long getElementOffset(int ordinal, int elementSize) { - return elementOffset + ordinal * elementSize; + return elementOffset + ordinal * (long)elementSize; } public Object getBaseObject() { return baseObject; } @@ -414,7 +414,7 @@ public byte[] toByteArray() { public short[] toShortArray() { short[] values = new short[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2); + baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2L); return values; } @@ -422,7 +422,7 @@ public short[] toShortArray() { public int[] toIntArray() { int[] values = new int[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -430,7 +430,7 @@ public int[] toIntArray() { public long[] toLongArray() { long[] values = new long[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8L); return values; } @@ -438,7 +438,7 @@ public long[] toLongArray() { public float[] toFloatArray() { float[] values = new float[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -446,14 +446,14 @@ public float[] toFloatArray() { public double[] toDoubleArray() { double[] values = new double[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8L); return values; } - private static UnsafeArrayData fromPrimitiveArray( + public static UnsafeArrayData fromPrimitiveArray( Object arr, int offset, int length, int elementSize) { final long headerInBytes = calculateHeaderPortionInBytes(length); - final long valueRegionInBytes = elementSize * length; + final long valueRegionInBytes = (long)elementSize * length; final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; if (totalSizeInLongs > Integer.MAX_VALUE / 8) { throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " + @@ -463,14 +463,27 @@ private static UnsafeArrayData fromPrimitiveArray( final long[] data = new long[(int)totalSizeInLongs]; Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length); - Platform.copyMemory(arr, offset, data, - Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); + if (arr != null) { + Platform.copyMemory(arr, offset, data, + Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); + } UnsafeArrayData result = new UnsafeArrayData(); result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8); return result; } + public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elementSize) { + return fromPrimitiveArray(null, offset, length, elementSize); + } + + public static boolean shouldUseGenericArrayData(int elementSize, int length) { + final long headerInBytes = calculateHeaderPortionInBytes(length); + final long valueRegionInBytes = (long)elementSize * length; + final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; + return totalSizeInLongs > Integer.MAX_VALUE / 8; + } + public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) { return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 29a1411241cf6..469b0e60cc9a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -62,6 +62,8 @@ */ public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable { + public static final int WORD_SIZE = 8; + ////////////////////////////////////////////////////////////////////////////// // Static methods ////////////////////////////////////////////////////////////////////////////// diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java index 905e6820ce6e2..c823de4810f2b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java @@ -41,7 +41,7 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB @Override public UnsafeRow appendRow(Object kbase, long koff, int klen, Object vbase, long voff, int vlen) { - final long recordLength = 8 + klen + vlen + 8; + final long recordLength = 8L + klen + vlen + 8; // if run out of max supported rows or page size, return null if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) { return null; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java index d224332d8a6c9..023ec139652c5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java @@ -21,6 +21,9 @@ import java.io.Reader; import javax.xml.namespace.QName; +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.parsers.ParserConfigurationException; import javax.xml.xpath.XPath; import javax.xml.xpath.XPathConstants; import javax.xml.xpath.XPathExpression; @@ -37,9 +40,15 @@ * This is based on Hive's UDFXPathUtil implementation. */ public class UDFXPathUtil { + public static final String SAX_FEATURE_PREFIX = "http://xml.org/sax/features/"; + public static final String EXTERNAL_GENERAL_ENTITIES_FEATURE = "external-general-entities"; + public static final String EXTERNAL_PARAMETER_ENTITIES_FEATURE = "external-parameter-entities"; + private DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); + private DocumentBuilder builder = null; private XPath xpath = XPathFactory.newInstance().newXPath(); private ReusableStringReader reader = new ReusableStringReader(); private InputSource inputSource = new InputSource(reader); + private XPathExpression expression = null; private String oldPath = null; @@ -65,14 +74,31 @@ public Object eval(String xml, String path, QName qname) throws XPathExpressionE return null; } + if (builder == null){ + try { + initializeDocumentBuilderFactory(); + builder = dbf.newDocumentBuilder(); + } catch (ParserConfigurationException e) { + throw new RuntimeException( + "Error instantiating DocumentBuilder, cannot build xml parser", e); + } + } + reader.set(xml); try { - return expression.evaluate(inputSource, qname); + return expression.evaluate(builder.parse(inputSource), qname); } catch (XPathExpressionException e) { throw new RuntimeException("Invalid XML document: " + e.getMessage() + "\n" + xml, e); + } catch (Exception e) { + throw new RuntimeException("Error loading expression '" + oldPath + "'", e); } } + private void initializeDocumentBuilderFactory() throws ParserConfigurationException { + dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_GENERAL_ENTITIES_FEATURE, false); + dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_PARAMETER_ENTITIES_FEATURE, false); + } + public Boolean evalBoolean(String xml, String path) throws XPathExpressionException { return (Boolean) eval(xml, path, XPathConstants.BOOLEAN); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index ccdb6bc5d4b7c..7b02317b8538f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -68,10 +68,10 @@ import org.apache.spark.sql.types._ */ @Experimental @InterfaceStability.Evolving -@implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " + - "(Int, String, etc) and Product types (case classes) are supported by importing " + - "spark.implicits._ Support for serializing other types will be added in future " + - "releases.") +@implicitNotFound("Unable to find encoder for type ${T}. An implicit Encoder[${T}] is needed to " + + "store ${T} instances in a Dataset. Primitive types (Int, String, etc) and Product types (case " + + "classes) are supported by importing spark.implicits._ Support for serializing other types " + + "will be added in future releases.") trait Encoder[T] extends Serializable { /** Returns the schema of encoding this type of object as a Row. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index 0b95a8821b05a..b47ec0b72c638 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -132,7 +132,7 @@ object Encoders { * - primitive types: boolean, int, double, etc. * - boxed types: Boolean, Integer, Double, etc. * - String - * - java.math.BigDecimal + * - java.math.BigDecimal, java.math.BigInteger * - time related: java.sql.Date, java.sql.Timestamp * - collection types: only array and java.util.List currently, map support is in progress * - nested java bean. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 474ec592201d9..6f5fbdd79e668 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -170,6 +170,9 @@ object CatalystTypeConverters { convertedIterable += elementConverter.toCatalyst(item) } new GenericArrayData(convertedIterable.toArray) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to an array of ${elementType.catalogString}") } } @@ -206,6 +209,10 @@ object CatalystTypeConverters { scalaValue match { case map: Map[_, _] => ArrayBasedMapData(map, keyFunction, valueFunction) case javaMap: JavaMap[_, _] => ArrayBasedMapData(javaMap, keyFunction, valueFunction) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + "cannot be converted to a map type with " + + s"key type (${keyType.catalogString}) and value type (${valueType.catalogString})") } } @@ -252,6 +259,9 @@ object CatalystTypeConverters { idx += 1 } new GenericInternalRow(ar) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to ${structType.catalogString}") } override def toScala(row: InternalRow): Row = { @@ -276,6 +286,10 @@ object CatalystTypeConverters { override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { case str: String => UTF8String.fromString(str) case utf8: UTF8String => utf8 + case chr: Char => UTF8String.fromString(chr.toString) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to the string type") } override def toScala(catalystValue: UTF8String): String = if (catalystValue == null) null else catalystValue.toString @@ -309,6 +323,9 @@ object CatalystTypeConverters { case d: JavaBigDecimal => Decimal(d) case d: JavaBigInteger => Decimal(d) case d: Decimal => d + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to ${dataType.catalogString}") } decimal.toPrecision(dataType.precision, dataType.scale) } @@ -414,6 +431,12 @@ object CatalystTypeConverters { map, (key: Any) => convertToCatalyst(key), (value: Any) => convertToCatalyst(value)) + case (keys: Array[_], values: Array[_]) => + // case for mapdata with duplicate keys + new ArrayBasedMapData( + new GenericArrayData(keys.map(convertToCatalyst)), + new GenericArrayData(values.map(convertToCatalyst)) + ) case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 818cc2fb1e8a8..4543bba8f6ed4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -798,7 +798,12 @@ object ScalaReflection extends ScalaReflection { * Whether the fields of the given type is defined entirely by its constructor parameters. */ def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects { - tpe.dealias <:< localTypeOf[Product] || tpe.dealias <:< localTypeOf[DefinedByConstructorParams] + tpe.dealias match { + // `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a struct type. + case t if t <:< localTypeOf[Option[_]] => definedByConstructorParams(t.typeArgs.head) + case _ => tpe.dealias <:< localTypeOf[Product] || + tpe.dealias <:< localTypeOf[DefinedByConstructorParams] + } } private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch", @@ -846,6 +851,19 @@ object ScalaReflection extends ScalaReflection { } } + def javaBoxedType(dt: DataType): Class[_] = dt match { + case _: DecimalType => classOf[Decimal] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayType] + case _: MapType => classOf[MapType] + case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType) + case ObjectType(cls) => cls + case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { if (arguments != Nil) { arguments.map(e => dataTypeJavaClass(e.dataType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e821e96522f7c..36f14ccdc6989 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random @@ -275,9 +278,9 @@ class Analyzer( case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) => g.copy(aggregations = assignAliases(g.aggregations)) - case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) - if child.resolved && hasUnresolvedAlias(groupByExprs) => - Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child) + case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) + if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) => + Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child) case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => Project(assignAliases(projectList), child) @@ -504,9 +507,15 @@ class Analyzer( object ResolvePivot extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) - | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p - case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => + case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) + || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) + || !p.pivotColumn.resolved => p + case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => + // Check all aggregate expressions. + aggregates.foreach(checkValidAggregateExpression) + // Group-by expressions coming from SQL are implicit and need to be deduced. + val groupByExprs = groupByExprsOpt.getOrElse( + (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) val singleAgg = aggregates.size == 1 def outputName(value: Literal, aggregate: Expression): String = { val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) @@ -568,16 +577,25 @@ class Analyzer( // TODO: Don't construct the physical container until after analysis. case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId) } - if (filteredAggregate.fastEquals(aggregate)) { - throw new AnalysisException( - s"Aggregate expression required for pivot, found '$aggregate'") - } Alias(filteredAggregate, outputName(value, aggregate))() } } Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) } } + + // Support any aggregate expression that can appear in an Aggregate plan except Pandas UDF. + // TODO: Support Pandas UDF. + private def checkValidAggregateExpression(expr: Expression): Unit = expr match { + case _: AggregateExpression => // OK and leave the argument check to CheckAnalysis. + case expr: PythonUDF if PythonUDF.isGroupedAggPandasUDF(expr) => + failAnalysis("Pandas UDF aggregate expressions are currently not supported in pivot.") + case e: Attribute => + failAnalysis( + s"Aggregate expression required for pivot, but '${e.sql}' " + + s"did not appear in any aggregate function.") + case e => e.children.foreach(checkValidAggregateExpression) + } } /** @@ -661,13 +679,13 @@ class Analyzer( try { catalog.lookupRelation(tableIdentWithDb) } catch { - case _: NoSuchTableException => - u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}") + case e: NoSuchTableException => + u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}", e) // If the database is defined and that database is not found, throw an AnalysisException. // Note that if the database is not defined, it is possible we are looking up a temp view. case e: NoSuchDatabaseException => u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " + - s"database ${e.db} doesn't exist.") + s"database ${e.db} doesn't exist.", e) } } @@ -723,6 +741,10 @@ class Analyzer( if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + case oldVersion @ FlatMapGroupsInPandas(_, _, output, _) + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(output = output.map(_.newInstance()))) + case oldVersion: Generate if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty => val newOutput = oldVersion.generatorOutput.map(_.newInstance()) @@ -1110,7 +1132,8 @@ class Analyzer( case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if !s.resolved && child.resolved => + case s @ Sort(order, _, child) + if (!s.resolved || s.missingInput.nonEmpty) && child.resolved => val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child) val ordering = newOrder.map(_.asInstanceOf[SortOrder]) if (child.output == newChild.output) { @@ -1121,7 +1144,7 @@ class Analyzer( Project(child.output, newSort) } - case f @ Filter(cond, child) if !f.resolved && child.resolved => + case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved => val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child) if (child.output == newChild.output) { f.copy(condition = newCond.head) @@ -1132,10 +1155,17 @@ class Analyzer( } } + /** + * This method tries to resolve expressions and find missing attributes recursively. Specially, + * when the expressions used in `Sort` or `Filter` contain unresolved attributes or resolved + * attributes which are missed from child output. This method tries to find the missing + * attributes out and add into the projection. + */ private def resolveExprsAndAddMissingAttrs( exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { - if (exprs.forall(_.resolved)) { - // All given expressions are resolved, no need to continue anymore. + // Missing attributes can be unresolved attributes or resolved attributes which are not in + // the output attributes of the plan. + if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) { (exprs, plan) } else { plan match { @@ -1146,15 +1176,19 @@ class Analyzer( (newExprs, AnalysisBarrier(newChild)) case p: Project => + // Resolving expressions against current plan. val maybeResolvedExprs = exprs.map(resolveExpression(_, p)) + // Recursively resolving expressions on the child of current plan. val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) - val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + // If some attributes used by expressions are resolvable only on the rewritten child + // plan, we need to add them into original projection. + val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet) (newExprs, Project(p.projectList ++ missingAttrs, newChild)) case a @ Aggregate(groupExprs, aggExprs, child) => val maybeResolvedExprs = exprs.map(resolveExpression(_, a)) val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) - val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { // All the missing attributes are grouping expressions, valid case. (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) @@ -1189,16 +1223,46 @@ class Analyzer( * only performs simple existence check according to the function identifier to quickly identify * undefined functions without triggering relation resolution, which may incur potentially * expensive partition/schema discovery process in some cases. - * + * In order to avoid duplicate external functions lookup, the external function identifier will + * store in the local hash set externalFunctionNameSet. * @see [[ResolveFunctions]] * @see https://issues.apache.org/jira/browse/SPARK-19737 */ object LookupFunctions extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { - case f: UnresolvedFunction if !catalog.functionExists(f.name) => - withPosition(f) { - throw new NoSuchFunctionException(f.name.database.getOrElse("default"), f.name.funcName) - } + override def apply(plan: LogicalPlan): LogicalPlan = { + val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]() + plan.transformAllExpressions { + case f: UnresolvedFunction + if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f + case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f + case f: UnresolvedFunction if catalog.isPersistentFunction(f.name) => + externalFunctionNameSet.add(normalizeFuncName(f.name)) + f + case f: UnresolvedFunction => + withPosition(f) { + throw new NoSuchFunctionException(f.name.database.getOrElse(catalog.getCurrentDatabase), + f.name.funcName) + } + } + } + + def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = { + val funcName = if (conf.caseSensitiveAnalysis) { + name.funcName + } else { + name.funcName.toLowerCase(Locale.ROOT) + } + + val databaseName = name.database match { + case Some(a) => formatDatabaseName(a) + case None => catalog.getCurrentDatabase + } + + FunctionIdentifier(funcName, Some(databaseName)) + } + + protected def formatDatabaseName(name: String): String = { + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } } @@ -1474,7 +1538,11 @@ class Analyzer( // Try resolving the ordering as though it is in the aggregate clause. try { - val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) + // If a sort order is unresolved, containing references not in aggregate, or containing + // `AggregateExpression`, we need to push down it to the underlying aggregate operator. + val unresolvedSortOrders = sortOrder.filter { s => + !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s) + } val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) @@ -1724,15 +1792,16 @@ class Analyzer( * 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions * it two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for * all regular expressions. - * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s. - * 3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts - * it into the plan tree. + * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s + * and [[WindowFunctionType]]s. + * 3. For every distinct [[WindowSpecDefinition]] and [[WindowFunctionType]], creates a + * [[Window]] operator and inserts it into the plan tree. */ object ExtractWindowExpressions extends Rule[LogicalPlan] { - private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = - projectList.exists(hasWindowFunction) + private def hasWindowFunction(exprs: Seq[Expression]): Boolean = + exprs.exists(hasWindowFunction) - private def hasWindowFunction(expr: NamedExpression): Boolean = { + private def hasWindowFunction(expr: Expression): Boolean = { expr.find { case window: WindowExpression => true case _ => false @@ -1815,6 +1884,10 @@ class Analyzer( seenWindowAggregates += newAgg WindowExpression(newAgg, spec) + case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) => + failAnalysis("It is not allowed to use a window function inside an aggregate " + + "function. Please use the inner window function in a sub-query.") + // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...), // we need to extract SUM(x). case agg: AggregateExpression if !seenWindowAggregates.contains(agg) => @@ -1882,7 +1955,7 @@ class Analyzer( s"Please file a bug report with this error message, stack trace, and the query.") } else { val spec = distinctWindowSpec.head - (spec.partitionSpec, spec.orderSpec) + (spec.partitionSpec, spec.orderSpec, WindowFunctionType.functionType(expr)) } }.toSeq @@ -1890,7 +1963,7 @@ class Analyzer( // setting this to the child of the next Window operator. val windowOps = groupedWindowExpressions.foldLeft(child) { - case (last, ((partitionSpec, orderSpec), windowExpressions)) => + case (last, ((partitionSpec, orderSpec, _), windowExpressions)) => Window(windowExpressions, partitionSpec, orderSpec, last) } @@ -1903,6 +1976,9 @@ class Analyzer( // "Aggregate with Having clause" will be triggered. def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case Filter(condition, _) if hasWindowFunction(condition) => + failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses") + // Aggregate with Having clause. This rule works with an unresolved Aggregate because // a resolved Aggregate will not have Window Functions. case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 90bda2a72ad82..af256b98b34f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ @@ -112,12 +113,19 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis("An offset window function can only be evaluated in an ordered " + s"row-based window frame with a single offset: $w") + case _ @ WindowExpression(_: PythonUDF, + WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame)) + if !frame.isUnbounded => + failAnalysis("Only unbounded window frame is supported with Pandas UDFs.") + case w @ WindowExpression(e, s) => // Only allow window functions with an aggregate expression or an offset window - // function. + // function or a Pandas window UDF. e match { case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction => w + case f: PythonUDF if PythonUDF.isWindowPandasUDF(f) => + w case _ => failAnalysis(s"Expression '$e' not supported within a window function.") } @@ -154,7 +162,7 @@ trait CheckAnalysis extends PredicateHelper { case Aggregate(groupingExprs, aggregateExprs, child) => def isAggregateExpression(expr: Expression) = { - expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr) + expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr) } def checkValidAggregateExpression(expr: Expression): Unit = expr match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c41f16c61d7a2..1d9e470c9ba83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -299,6 +299,15 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), + expression[RegrCount]("regr_count"), + expression[RegrSXX]("regr_sxx"), + expression[RegrSYY]("regr_syy"), + expression[RegrAvgX]("regr_avgx"), + expression[RegrAvgY]("regr_avgy"), + expression[RegrSXY]("regr_sxy"), + expression[RegrSlope]("regr_slope"), + expression[RegrR2]("regr_r2"), + expression[RegrIntercept]("regr_intercept"), // string functions expression[Ascii]("ascii"), @@ -401,20 +410,44 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArraysOverlap]("arrays_overlap"), + expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), + expression[ArraySort]("array_sort"), + expression[ArrayUnion]("array_union"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), + expression[MapFromArrays]("map_from_arrays"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), + expression[MapEntries]("map_entries"), + expression[MapFromEntries]("map_from_entries"), + expression[MapConcat]("map_concat"), expression[Size]("size"), + expression[Slice]("slice"), + expression[Size]("cardinality"), + expression[ArraysZip]("arrays_zip"), expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), expression[Reverse]("reverse"), expression[Concat]("concat"), + expression[Flatten]("flatten"), + expression[Sequence]("sequence"), + expression[ArrayRepeat]("array_repeat"), + expression[ArrayRemove]("array_remove"), + expression[ArrayDistinct]("array_distinct"), CreateStruct.registryEntry, + // mask functions + expression[Mask]("mask"), + expression[MaskFirstN]("mask_first_n"), + expression[MaskLastN]("mask_last_n"), + expression[MaskShowFirstN]("mask_show_first_n"), + expression[MaskShowLastN]("mask_show_last_n"), + expression[MaskHash]("mask_hash"), + // misc functions expression[AssertTrue]("assert_true"), expression[Crc32]("crc32"), @@ -474,6 +507,7 @@ object FunctionRegistry { // json expression[StructsToJson]("to_json"), expression[JsonToStructs]("from_json"), + expression[SchemaOfJson]("schema_of_json"), // cast expression[Cast]("cast"), @@ -531,7 +565,9 @@ object FunctionRegistry { // Otherwise, find a constructor method that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { - val validParametersCount = constructors.map(_.getParameterCount).distinct.sorted + val validParametersCount = constructors + .filter(_.getParameterTypes.forall(_ == classOf[Expression])) + .map(_.getParameterCount).distinct.sorted val expectedNumberOfParameters = if (validParametersCount.length == 1) { validParametersCount.head.toString } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index f2df3e132629f..71ed75454cd4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -103,7 +103,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas castedExpr.eval() } catch { case NonFatal(ex) => - table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") + table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}", ex) } }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index cfcbd8db559a3..e8331c90ea0f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -59,7 +59,7 @@ object TypeCoercion { IfCoercion :: StackCoercion :: Division :: - ImplicitTypeCasts :: + new ImplicitTypeCasts(conf) :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -102,17 +102,7 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) - case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => - Some(StructType(fields1.zip(fields2).map { case (f1, f2) => - // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType - // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. - // - Different names: use f1.name - // - Different nullabilities: `nullable` is true iff one of them is nullable. - val dataType = findTightestCommonType(f1.dataType, f2.dataType).get - StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) - })) - - case _ => None + case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType) } /** Promotes all the way to StringType. */ @@ -158,6 +148,42 @@ object TypeCoercion { case (l, r) => None } + private def findTypeForComplex( + t1: DataType, + t2: DataType, + findTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match { + case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => + findTypeFunc(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) + case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => + findTypeFunc(kt1, kt2).flatMap { kt => + findTypeFunc(vt1, vt2).map { vt => + MapType(kt, vt, valueContainsNull1 || valueContainsNull2) + } + } + case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length => + val resolver = SQLConf.get.resolver + fields1.zip(fields2).foldLeft(Option(new StructType())) { + case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) => + findTypeFunc(field1.dataType, field2.dataType).map { + dt => struct.add(field1.name, dt, field1.nullable || field2.nullable) + } + case _ => None + } + case _ => None + } + + /** + * The method finds a common type for data types that differ only in nullable, containsNull + * and valueContainsNull flags. If the input types are too different, None is returned. + */ + def findCommonTypeDifferentOnlyInNullFlags(t1: DataType, t2: DataType): Option[DataType] = { + if (t1 == t2) { + Some(t1) + } else { + findTypeForComplex(t1, t2, findCommonTypeDifferentOnlyInNullFlags) + } + } + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). * @@ -168,11 +194,7 @@ object TypeCoercion { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) - .orElse((t1, t2) match { - case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) - case _ => None - }) + .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo)) } /** @@ -208,12 +230,7 @@ object TypeCoercion { t2: DataType): Option[DataType] = { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) - .orElse((t1, t2) match { - case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(et1, et2) - .map(ArrayType(_, containsNull1 || containsNull2)) - case _ => None - }) + .orElse(findTypeForComplex(t1, t2, findWiderTypeWithoutStringPromotionForTwo)) } def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { @@ -528,6 +545,29 @@ object TypeCoercion { case None => c } + case aj @ ArrayJoin(arr, d, nr) if !ArrayType(StringType).acceptsType(arr.dataType) && + ArrayType.acceptsType(arr.dataType) => + val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull + ImplicitTypeCasts.implicitCast(arr, ArrayType(StringType, containsNull)) match { + case Some(castedArr) => ArrayJoin(castedArr, d, nr) + case None => aj + } + + case s @ Sequence(_, _, _, timeZoneId) if !haveSameType(s.coercibleChildren) => + val types = s.coercibleChildren.map(_.dataType) + findWiderCommonType(types) match { + case Some(widerDataType) => s.castChildrenTo(widerDataType) + case None => s + } + + case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) && + !haveSameType(children) => + val types = children.map(_.dataType) + findWiderCommonType(types) match { + case Some(finalDataType) => MapConcat(children.map(Cast(_, finalDataType))) + case None => m + } + case m @ CreateMap(children) if m.keys.length == m.values.length && (!haveSameType(m.keys) || !haveSameType(m.values)) => val newKeys = if (haveSameType(m.keys)) { @@ -632,8 +672,8 @@ object TypeCoercion { object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => - val maybeCommonType = findWiderCommonType(c.valueTypes) + case c: CaseWhen if c.childrenResolved && !c.areInputTypesForMergingEqual => + val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) maybeCommonType.map { commonType => var changed = false val newBranches = c.branches.map { case (condition, value) => @@ -665,10 +705,10 @@ object TypeCoercion { plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. - case i @ If(pred, left, right) if left.dataType != right.dataType => + case i @ If(pred, left, right) if !i.areInputTypesForMergingEqual => findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + val newLeft = if (left.dataType.sameType(widestType)) left else Cast(left, widestType) + val newRight = if (right.dataType.sameType(widestType)) right else Cast(right, widestType) If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => @@ -768,12 +808,33 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - object ImplicitTypeCasts extends TypeCoercionRule { + class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { + + private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) + override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e + // Special rules for `from/to_utc_timestamp`. These 2 functions assume the input timestamp + // string is in a specific timezone, so the string itself should not contain timezone. + // TODO: We should move the type coercion logic to expressions instead of a central + // place to put all the rules. + case e: FromUTCTimestamp if e.left.dataType == StringType => + if (rejectTzInString) { + e.copy(left = StringToTimestampWithoutTimezone(e.left)) + } else { + e.copy(left = Cast(e.left, TimestampType)) + } + + case e: ToUTCTimestamp if e.left.dataType == StringType => + if (rejectTzInString) { + e.copy(left = StringToTimestampWithoutTimezone(e.left)) + } else { + e.copy(left = Cast(e.left, TimestampType)) + } + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonType(left.dataType, right.dataType).map { commonType => if (b.inputType.acceptsType(commonType)) { @@ -790,7 +851,7 @@ object TypeCoercion { case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => // If we cannot do the implicit cast, just use the original input. - implicitCast(in, expected).getOrElse(in) + ImplicitTypeCasts.implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) @@ -806,6 +867,9 @@ object TypeCoercion { } e.withNewChildren(children) } + } + + object ImplicitTypeCasts { /** * Given an expected data type, try to cast the expression and return the cast expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index ff9d6d7a7dded..f68df5d29b545 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode /** @@ -314,8 +315,10 @@ object UnsupportedOperationChecker { case GroupingSets(_, _, child, _) if child.isStreaming => throwError("GroupingSets is not supported on streaming DataFrames/Datasets") - case GlobalLimit(_, _) | LocalLimit(_, _) if subPlan.children.forall(_.isStreaming) => - throwError("Limits are not supported on streaming DataFrames/Datasets") + case GlobalLimit(_, _) | LocalLimit(_, _) + if subPlan.children.forall(_.isStreaming) && outputMode == InternalOutputModes.Update => + throwError("Limits are not supported on streaming DataFrames/Datasets in Update " + + "output mode") case Sort(_, _, _) if !containsCompleteData(subPlan) => throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " + @@ -345,8 +348,20 @@ object UnsupportedOperationChecker { plan.foreachUp { implicit subPlan => subPlan match { case (_: Project | _: Filter | _: MapElements | _: MapPartitions | - _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias) => + _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias | + _: TypedFilter) => case node if node.nodeName == "StreamingRelationV2" => + case Repartition(1, false, _) => + case node: Aggregate => + val aboveSinglePartitionCoalesce = node.find { + case Repartition(1, false, _) => true + case _ => false + }.isDefined + + if (!aboveSinglePartitionCoalesce) { + throwError(s"In continuous processing mode, coalesce(1) must be called before " + + s"aggregate operation ${node.nodeName}.") + } case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index 7731336d247db..354a3fa0602a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -41,6 +41,11 @@ package object analysis { def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg, t.origin.line, t.origin.startPosition) } + + /** Fails the analysis at the point where a specific tree node was parsed. */ + def failAnalysis(msg: String, cause: Throwable): Nothing = { + throw new AnalysisException(msg, t.origin.line, t.origin.startPosition, cause = Some(cause)) + } } /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 45b4f013620c1..1a145c24d78cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.catalog -import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException} +import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchPartitionException, NoSuchTableException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ListenerBus /** * Interface for the system catalog (of functions, partitions, tables, and databases). @@ -31,10 +30,13 @@ import org.apache.spark.util.ListenerBus * * Implementations should throw [[NoSuchDatabaseException]] when databases don't exist. */ -abstract class ExternalCatalog - extends ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { +trait ExternalCatalog { import CatalogTypes.TablePartitionSpec + // -------------------------------------------------------------------------- + // Utils + // -------------------------------------------------------------------------- + protected def requireDbExists(db: String): Unit = { if (!databaseExists(db)) { throw new NoSuchDatabaseException(db) @@ -63,22 +65,9 @@ abstract class ExternalCatalog // Databases // -------------------------------------------------------------------------- - final def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { - val db = dbDefinition.name - postToAll(CreateDatabasePreEvent(db)) - doCreateDatabase(dbDefinition, ignoreIfExists) - postToAll(CreateDatabaseEvent(db)) - } + def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit - protected def doCreateDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit - - final def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { - postToAll(DropDatabasePreEvent(db)) - doDropDatabase(db, ignoreIfNotExists, cascade) - postToAll(DropDatabaseEvent(db)) - } - - protected def doDropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit /** * Alter a database whose name matches the one specified in `dbDefinition`, @@ -87,14 +76,7 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - final def alterDatabase(dbDefinition: CatalogDatabase): Unit = { - val db = dbDefinition.name - postToAll(AlterDatabasePreEvent(db)) - doAlterDatabase(dbDefinition) - postToAll(AlterDatabaseEvent(db)) - } - - protected def doAlterDatabase(dbDefinition: CatalogDatabase): Unit + def alterDatabase(dbDefinition: CatalogDatabase): Unit def getDatabase(db: String): CatalogDatabase @@ -110,41 +92,15 @@ abstract class ExternalCatalog // Tables // -------------------------------------------------------------------------- - final def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { - val db = tableDefinition.database - val name = tableDefinition.identifier.table - val tableDefinitionWithVersion = - tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION) - postToAll(CreateTablePreEvent(db, name)) - doCreateTable(tableDefinitionWithVersion, ignoreIfExists) - postToAll(CreateTableEvent(db, name)) - } - - protected def doCreateTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - - final def dropTable( - db: String, - table: String, - ignoreIfNotExists: Boolean, - purge: Boolean): Unit = { - postToAll(DropTablePreEvent(db, table)) - doDropTable(db, table, ignoreIfNotExists, purge) - postToAll(DropTableEvent(db, table)) - } + def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - protected def doDropTable( + def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit - final def renameTable(db: String, oldName: String, newName: String): Unit = { - postToAll(RenameTablePreEvent(db, oldName, newName)) - doRenameTable(db, oldName, newName) - postToAll(RenameTableEvent(db, oldName, newName)) - } - - protected def doRenameTable(db: String, oldName: String, newName: String): Unit + def renameTable(db: String, oldName: String, newName: String): Unit /** * Alter a table whose database and name match the ones specified in `tableDefinition`, assuming @@ -154,15 +110,7 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - final def alterTable(tableDefinition: CatalogTable): Unit = { - val db = tableDefinition.database - val name = tableDefinition.identifier.table - postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) - doAlterTable(tableDefinition) - postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) - } - - protected def doAlterTable(tableDefinition: CatalogTable): Unit + def alterTable(tableDefinition: CatalogTable): Unit /** * Alter the data schema of a table identified by the provided database and table name. The new @@ -173,22 +121,10 @@ abstract class ExternalCatalog * @param table Name of table to alter schema for * @param newDataSchema Updated data schema to be used for the table. */ - final def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { - postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) - doAlterTableDataSchema(db, table, newDataSchema) - postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) - } - - protected def doAlterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit + def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ - final def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit = { - postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) - doAlterTableStats(db, table, stats) - postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) - } - - protected def doAlterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit + def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit def getTable(db: String, table: String): CatalogTable @@ -340,49 +276,17 @@ abstract class ExternalCatalog // Functions // -------------------------------------------------------------------------- - final def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { - val name = funcDefinition.identifier.funcName - postToAll(CreateFunctionPreEvent(db, name)) - doCreateFunction(db, funcDefinition) - postToAll(CreateFunctionEvent(db, name)) - } + def createFunction(db: String, funcDefinition: CatalogFunction): Unit - protected def doCreateFunction(db: String, funcDefinition: CatalogFunction): Unit + def dropFunction(db: String, funcName: String): Unit - final def dropFunction(db: String, funcName: String): Unit = { - postToAll(DropFunctionPreEvent(db, funcName)) - doDropFunction(db, funcName) - postToAll(DropFunctionEvent(db, funcName)) - } + def alterFunction(db: String, funcDefinition: CatalogFunction): Unit - protected def doDropFunction(db: String, funcName: String): Unit - - final def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { - val name = funcDefinition.identifier.funcName - postToAll(AlterFunctionPreEvent(db, name)) - doAlterFunction(db, funcDefinition) - postToAll(AlterFunctionEvent(db, name)) - } - - protected def doAlterFunction(db: String, funcDefinition: CatalogFunction): Unit - - final def renameFunction(db: String, oldName: String, newName: String): Unit = { - postToAll(RenameFunctionPreEvent(db, oldName, newName)) - doRenameFunction(db, oldName, newName) - postToAll(RenameFunctionEvent(db, oldName, newName)) - } - - protected def doRenameFunction(db: String, oldName: String, newName: String): Unit + def renameFunction(db: String, oldName: String, newName: String): Unit def getFunction(db: String, funcName: String): CatalogFunction def functionExists(db: String, funcName: String): Boolean def listFunctions(db: String, pattern: String): Seq[String] - - override protected def doPostEvent( - listener: ExternalCatalogEventListener, - event: ExternalCatalogEvent): Unit = { - listener.onEvent(event) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala new file mode 100644 index 0000000000000..2f009be5816fa --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ListenerBus + +/** + * Wraps an ExternalCatalog to provide listener events. + */ +class ExternalCatalogWithListener(delegate: ExternalCatalog) + extends ExternalCatalog + with ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { + import CatalogTypes.TablePartitionSpec + + def unwrapped: ExternalCatalog = delegate + + override protected def doPostEvent( + listener: ExternalCatalogEventListener, + event: ExternalCatalogEvent): Unit = { + listener.onEvent(event) + } + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { + val db = dbDefinition.name + postToAll(CreateDatabasePreEvent(db)) + delegate.createDatabase(dbDefinition, ignoreIfExists) + postToAll(CreateDatabaseEvent(db)) + } + + override def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { + postToAll(DropDatabasePreEvent(db)) + delegate.dropDatabase(db, ignoreIfNotExists, cascade) + postToAll(DropDatabaseEvent(db)) + } + + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = { + val db = dbDefinition.name + postToAll(AlterDatabasePreEvent(db)) + delegate.alterDatabase(dbDefinition) + postToAll(AlterDatabaseEvent(db)) + } + + override def getDatabase(db: String): CatalogDatabase = { + delegate.getDatabase(db) + } + + override def databaseExists(db: String): Boolean = { + delegate.databaseExists(db) + } + + override def listDatabases(): Seq[String] = { + delegate.listDatabases() + } + + override def listDatabases(pattern: String): Seq[String] = { + delegate.listDatabases(pattern) + } + + override def setCurrentDatabase(db: String): Unit = { + delegate.setCurrentDatabase(db) + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + override def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + val tableDefinitionWithVersion = + tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION) + postToAll(CreateTablePreEvent(db, name)) + delegate.createTable(tableDefinitionWithVersion, ignoreIfExists) + postToAll(CreateTableEvent(db, name)) + } + + override def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit = { + postToAll(DropTablePreEvent(db, table)) + delegate.dropTable(db, table, ignoreIfNotExists, purge) + postToAll(DropTableEvent(db, table)) + } + + override def renameTable(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameTablePreEvent(db, oldName, newName)) + delegate.renameTable(db, oldName, newName) + postToAll(RenameTableEvent(db, oldName, newName)) + } + + override def alterTable(tableDefinition: CatalogTable): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) + delegate.alterTable(tableDefinition) + postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) + } + + override def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) + delegate.alterTableDataSchema(db, table, newDataSchema) + postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) + } + + override def alterTableStats( + db: String, + table: String, + stats: Option[CatalogStatistics]): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) + delegate.alterTableStats(db, table, stats) + postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) + } + + override def getTable(db: String, table: String): CatalogTable = { + delegate.getTable(db, table) + } + + override def tableExists(db: String, table: String): Boolean = { + delegate.tableExists(db, table) + } + + override def listTables(db: String): Seq[String] = { + delegate.listTables(db) + } + + override def listTables(db: String, pattern: String): Seq[String] = { + delegate.listTables(db, pattern) + } + + override def loadTable( + db: String, + table: String, + loadPath: String, + isOverwrite: Boolean, + isSrcLocal: Boolean): Unit = { + delegate.loadTable(db, table, loadPath, isOverwrite, isSrcLocal) + } + + override def loadPartition( + db: String, + table: String, + loadPath: String, + partition: TablePartitionSpec, + isOverwrite: Boolean, + inheritTableSpecs: Boolean, + isSrcLocal: Boolean): Unit = { + delegate.loadPartition( + db, table, loadPath, partition, isOverwrite, inheritTableSpecs, isSrcLocal) + } + + override def loadDynamicPartitions( + db: String, + table: String, + loadPath: String, + partition: TablePartitionSpec, + replace: Boolean, + numDP: Int): Unit = { + delegate.loadDynamicPartitions(db, table, loadPath, partition, replace, numDP) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + override def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = { + delegate.createPartitions(db, table, parts, ignoreIfExists) + } + + override def dropPartitions( + db: String, + table: String, + partSpecs: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean, + purge: Boolean, + retainData: Boolean): Unit = { + delegate.dropPartitions(db, table, partSpecs, ignoreIfNotExists, purge, retainData) + } + + override def renamePartitions( + db: String, + table: String, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = { + delegate.renamePartitions(db, table, specs, newSpecs) + } + + override def alterPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition]): Unit = { + delegate.alterPartitions(db, table, parts) + } + + override def getPartition( + db: String, + table: String, + spec: TablePartitionSpec): CatalogTablePartition = { + delegate.getPartition(db, table, spec) + } + + override def getPartitionOption( + db: String, + table: String, + spec: TablePartitionSpec): Option[CatalogTablePartition] = { + delegate.getPartitionOption(db, table, spec) + } + + override def listPartitionNames( + db: String, + table: String, + partialSpec: Option[TablePartitionSpec] = None): Seq[String] = { + delegate.listPartitionNames(db, table, partialSpec) + } + + override def listPartitions( + db: String, + table: String, + partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = { + delegate.listPartitions(db, table, partialSpec) + } + + override def listPartitionsByFilter( + db: String, + table: String, + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = { + delegate.listPartitionsByFilter(db, table, predicates, defaultTimeZoneId) + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + override def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(CreateFunctionPreEvent(db, name)) + delegate.createFunction(db, funcDefinition) + postToAll(CreateFunctionEvent(db, name)) + } + + override def dropFunction(db: String, funcName: String): Unit = { + postToAll(DropFunctionPreEvent(db, funcName)) + delegate.dropFunction(db, funcName) + postToAll(DropFunctionEvent(db, funcName)) + } + + override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(AlterFunctionPreEvent(db, name)) + delegate.alterFunction(db, funcDefinition) + postToAll(AlterFunctionEvent(db, name)) + } + + override def renameFunction(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameFunctionPreEvent(db, oldName, newName)) + delegate.renameFunction(db, oldName, newName) + postToAll(RenameFunctionEvent(db, oldName, newName)) + } + + override def getFunction(db: String, funcName: String): CatalogFunction = { + delegate.getFunction(db, funcName) + } + + override def functionExists(db: String, funcName: String): Boolean = { + delegate.functionExists(db, funcName) + } + + override def listFunctions(db: String, pattern: String): Seq[String] = { + delegate.listFunctions(db, pattern) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 8eacfa058bd52..741dc46b07382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -98,7 +98,7 @@ class InMemoryCatalog( // Databases // -------------------------------------------------------------------------- - override protected def doCreateDatabase( + override def createDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { @@ -119,7 +119,7 @@ class InMemoryCatalog( } } - override protected def doDropDatabase( + override def dropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = synchronized { @@ -152,7 +152,7 @@ class InMemoryCatalog( } } - override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { requireDbExists(dbDefinition.name) catalog(dbDefinition.name).db = dbDefinition } @@ -180,7 +180,7 @@ class InMemoryCatalog( // Tables // -------------------------------------------------------------------------- - override protected def doCreateTable( + override def createTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) @@ -221,7 +221,7 @@ class InMemoryCatalog( } } - override protected def doDropTable( + override def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -264,7 +264,7 @@ class InMemoryCatalog( } } - override protected def doRenameTable( + override def renameTable( db: String, oldName: String, newName: String): Unit = synchronized { @@ -294,7 +294,7 @@ class InMemoryCatalog( catalog(db).tables.remove(oldName) } - override def doAlterTable(tableDefinition: CatalogTable): Unit = synchronized { + override def alterTable(tableDefinition: CatalogTable): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -303,7 +303,7 @@ class InMemoryCatalog( catalog(db).tables(tableDefinition.identifier.table).table = newTableDefinition } - override def doAlterTableDataSchema( + override def alterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = synchronized { @@ -313,7 +313,7 @@ class InMemoryCatalog( catalog(db).tables(table).table = origTable.copy(schema = newSchema) } - override def doAlterTableStats( + override def alterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = synchronized { @@ -564,24 +564,24 @@ class InMemoryCatalog( // Functions // -------------------------------------------------------------------------- - override protected def doCreateFunction(db: String, func: CatalogFunction): Unit = synchronized { + override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionNotExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override protected def doDropFunction(db: String, funcName: String): Unit = synchronized { + override def dropFunction(db: String, funcName: String): Unit = synchronized { requireFunctionExists(db, funcName) catalog(db).functions.remove(funcName) } - override protected def doAlterFunction(db: String, func: CatalogFunction): Unit = synchronized { + override def alterFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override protected def doRenameFunction( + override def renameFunction( db: String, oldName: String, newName: String): Unit = synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c390337c03ff5..b09b81eabf60d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -619,6 +619,7 @@ class SessionCatalog( requireTableExists(TableIdentifier(oldTableName, Some(db))) requireTableNotExists(TableIdentifier(newTableName, Some(db))) validateName(newTableName) + validateNewLocationOfRename(oldName, newName) externalCatalog.renameTable(db, oldTableName, newTableName) } else { if (newName.database.isDefined) { @@ -1192,6 +1193,22 @@ class SessionCatalog( !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT)) } + /** + * Return whether this function has been registered in the function registry of the current + * session. If not existed, return false. + */ + def isRegisteredFunction(name: FunctionIdentifier): Boolean = { + functionRegistry.functionExists(name) + } + + /** + * Returns whether it is a persistent function. If not existed, returns false. + */ + def isPersistentFunction(name: FunctionIdentifier): Boolean = { + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + databaseExists(db) && externalCatalog.functionExists(db, name.funcName) + } + protected def failFunctionLookup(name: FunctionIdentifier): Nothing = { throw new NoSuchFunctionException( db = name.database.getOrElse(getCurrentDatabase), func = name.funcName) @@ -1366,4 +1383,23 @@ class SessionCatalog( // copy over temporary views tempViews.foreach(kv => target.tempViews.put(kv._1, kv._2)) } + + /** + * Validate the new locatoin before renaming a managed table, which should be non-existent. + */ + private def validateNewLocationOfRename( + oldName: TableIdentifier, + newName: TableIdentifier): Unit = { + val oldTable = getTableMetadata(oldName) + if (oldTable.tableType == CatalogTableType.MANAGED) { + val databaseLocation = + externalCatalog.getDatabase(oldName.database.getOrElse(currentDb)).locationUri + val newTableLocation = new Path(new Path(databaseLocation), formatTableName(newName.table)) + val fs = newTableLocation.getFileSystem(hadoopConf) + if (fs.exists(newTableLocation)) { + throw new AnalysisException(s"Can not rename the managed table('$oldName')" + + s". The associated location('$newTableLocation') already exists.") + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index f3e67dc4e975c..c6105c5526049 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -93,12 +93,16 @@ object CatalogStorageFormat { * @param spec partition spec values indexed by column name * @param storage storage format of the partition * @param parameters some parameters for the partition + * @param createTime creation time of the partition, in milliseconds + * @param lastAccessTime last access time, in milliseconds * @param stats optional statistics (number of rows, total size, etc.) */ case class CatalogTablePartition( spec: CatalogTypes.TablePartitionSpec, storage: CatalogStorageFormat, parameters: Map[String, String] = Map.empty, + createTime: Long = System.currentTimeMillis, + lastAccessTime: Long = -1, stats: Option[CatalogStatistics] = None) { def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { @@ -109,6 +113,8 @@ case class CatalogTablePartition( if (parameters.nonEmpty) { map.put("Partition Parameters", s"{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") } + map.put("Created Time", new Date(createTime).toString) + map.put("Last Access", new Date(lastAccessTime).toString) stats.foreach(s => map.put("Partition Statistics", s.simpleString)) map } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index efb2eba655e15..89e8c998f740d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -149,6 +149,7 @@ package object dsl { } } + def rand(e: Long): Expression = Rand(e) def sum(e: Expression): Expression = Sum(e).toAggregateExpression() def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true) def count(e: Expression): Expression = Count(e).toAggregateExpression() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index efc2882f0a3d3..cbea3c017a265 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -128,7 +128,7 @@ object ExpressionEncoder { case b: BoundReference if b == originalInputObject => newInputObject }) - if (enc.flat) { + val serializerExpr = if (enc.flat) { newSerializer.head } else { // For non-flat encoder, the input object is not top level anymore after being combined to @@ -146,6 +146,7 @@ object ExpressionEncoder { Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil)) If(nullCheck, Literal.create(null, struct.dataType), struct) } + Alias(serializerExpr, s"_${index + 1}")() } val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 789750fd408f2..3340789398f9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 4cc84b27d9eb0..df3ab05e02c76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -56,13 +57,13 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = - s""" + code""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); |$javaType ${ev.value} = ${ev.isNull} ? | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { - ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral) + ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index d848ba18356d3..7541f527a52a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -30,6 +30,7 @@ package org.apache.spark.sql.catalyst.expressions * by `hashCode`. * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. * - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`. + * - Elements in [[In]] are reordered by `hashCode`. */ object Canonicalize { def execute(e: Expression): Expression = { @@ -85,6 +86,11 @@ object Canonicalize { case Not(GreaterThanOrEqual(l, r)) => LessThan(l, r) case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r) + // order the list in the In operator + // In subqueries contain only one element of type ListQuery. So checking that the length > 1 + // we are not reordering In subqueries. + case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode())) + case _ => e } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 12330bfa55ab9..7971ae602bd37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -133,6 +134,26 @@ object Cast { toPrecedence > 0 && fromPrecedence > toPrecedence } + /** + * Returns true iff we can safely cast the `from` type to `to` type without any truncating or + * precision lose, e.g. int -> long, date -> timestamp. + */ + def canSafeCast(from: AtomicType, to: AtomicType): Boolean = (from, to) match { + case _ if from == to => true + case (from: NumericType, to: DecimalType) if to.isWiderThan(from) => true + case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true + case (from, to) if legalNumericPrecedence(from, to) => true + case (DateType, TimestampType) => true + case (_, StringType) => true + case _ => false + } + + private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = { + val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from) + val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to) + fromPrecedence >= 0 && fromPrecedence < toPrecedence + } + def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match { case (NullType, _) => true case (_, _) if from == to => false @@ -623,8 +644,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) - ev.copy(code = eval.code + - castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) + + ev.copy(code = + code""" + ${eval.code} + // This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull} + ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)} + """) } // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala new file mode 100644 index 0000000000000..fb25e781e72e4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.codehaus.commons.compiler.CompileException +import org.codehaus.janino.InternalCompilerException + +import org.apache.spark.TaskContext +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * Catches compile error during code generation. + */ +object CodegenError { + def unapply(throwable: Throwable): Option[Exception] = throwable match { + case e: InternalCompilerException => Some(e) + case e: CompileException => Some(e) + case _ => None + } +} + +/** + * Defines values for `SQLConf` config of fallback mode. Use for test only. + */ +object CodegenObjectFactoryMode extends Enumeration { + val FALLBACK, CODEGEN_ONLY, NO_CODEGEN = Value +} + +/** + * A codegen object generator which creates objects with codegen path first. Once any compile + * error happens, it can fallbacks to interpreted implementation. In tests, we can use a SQL config + * `SQLConf.CODEGEN_FACTORY_MODE` to control fallback behavior. + */ +abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] { + + def createObject(in: IN): OUT = { + // We are allowed to choose codegen-only or no-codegen modes if under tests. + val config = SQLConf.get.getConf(SQLConf.CODEGEN_FACTORY_MODE) + val fallbackMode = CodegenObjectFactoryMode.withName(config) + + fallbackMode match { + case CodegenObjectFactoryMode.CODEGEN_ONLY if Utils.isTesting => + createCodeGeneratedObject(in) + case CodegenObjectFactoryMode.NO_CODEGEN if Utils.isTesting => + createInterpretedObject(in) + case _ => + try { + createCodeGeneratedObject(in) + } catch { + case CodegenError(_) => createInterpretedObject(in) + } + } + } + + protected def createCodeGeneratedObject(in: IN): OUT + protected def createInterpretedObject(in: IN): OUT +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 97dff6ae88299..44c5556ff9ccf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -108,9 +109,9 @@ abstract class Expression extends TreeNode[Expression] { JavaCode.isNullVariable(isNull), JavaCode.variable(value, dataType))) reduceCodeSize(ctx, eval) - if (eval.code.nonEmpty) { + if (eval.code.toString.nonEmpty) { // Add `this` in the comment. - eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim) + eval.copy(code = ctx.registerComment(this.toString) + eval.code) } else { eval } @@ -119,7 +120,7 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too - if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull @@ -136,14 +137,14 @@ abstract class Expression extends TreeNode[Expression] { val funcFullName = ctx.addNewFunction(funcName, s""" |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | $setIsNull | return ${eval.value}; |} """.stripMargin) eval.value = JavaCode.variable(newValue, dataType) - eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } @@ -437,15 +438,14 @@ abstract class UnaryExpression extends Expression { if (nullable) { val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) - ev.copy(code = s""" + ev.copy(code = code""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = FalseLiteral) @@ -537,14 +537,13 @@ abstract class BinaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${leftGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -681,13 +680,12 @@ abstract class TernaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval""") } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${leftGen.code} ${midGen.code} ${rightGen.code} @@ -697,6 +695,41 @@ abstract class TernaryExpression extends Expression { } } +/** + * A trait resolving nullable, containsNull, valueContainsNull flags of the output date type. + * This logic is usually utilized by expressions combining data from multiple child expressions + * of non-primitive types (e.g. [[CaseWhen]]). + */ +trait ComplexTypeMergingExpression extends Expression { + + /** + * A collection of data types used for resolution the output type of the expression. By default, + * data types of all child expressions. The collection must not be empty. + */ + @transient + lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType) + + /** + * A method determining whether the input types are equal ignoring nullable, containsNull and + * valueContainsNull flags and thus convenient for resolution of the final data type. + */ + def areInputTypesForMergingEqual: Boolean = { + inputTypesForMerging.length <= 1 || inputTypesForMerging.sliding(2, 1).forall { + case Seq(dt1, dt2) => dt1.sameType(dt2) + } + } + + override def dataType: DataType = { + require( + inputTypesForMerging.nonEmpty, + "The collection of input data types must not be empty.") + require( + areInputTypesForMergingEqual, + "All input types must be the same except nullable, containsNull, valueContainsNull flags.") + inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) + } +} + /** * Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages * and Hive function wrappers. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 6d69d69b1c802..55a5bd380859e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -87,12 +87,11 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe /** * Helper functions for creating an [[InterpretedUnsafeProjection]]. */ -object InterpretedUnsafeProjection extends UnsafeProjectionCreator { - +object InterpretedUnsafeProjection { /** * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. */ - override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + def createProjection(exprs: Seq[Expression]): UnsafeProjection = { // We need to make sure that we do not reuse stateful expressions. val cleanedExpressions = exprs.map(_.transform { case s: Stateful => s.freshCopy() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index ad1e7bdb31987..f1da592a76845 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType} /** @@ -38,6 +39,7 @@ import org.apache.spark.sql.types.{DataType, LongType} puts the partition ID in the upper 31 bits, and the lower 33 bits represent the record number within each partition. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. + The function is non-deterministic because its result depends on partition IDs. """) case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { @@ -71,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; $countTerm++;""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 3cd73682188bc..6493f09100577 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -108,7 +108,32 @@ abstract class UnsafeProjection extends Projection { override def apply(row: InternalRow): UnsafeRow } -trait UnsafeProjectionCreator { +/** + * The factory object for `UnsafeProjection`. + */ +object UnsafeProjection + extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] { + + override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = { + GenerateUnsafeProjection.generate(in) + } + + override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = { + InterpretedUnsafeProjection.createProjection(in) + } + + protected def toBoundExprs( + exprs: Seq[Expression], + inputSchema: Seq[Attribute]): Seq[Expression] = { + exprs.map(BindReferences.bindReference(_, inputSchema)) + } + + protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = { + exprs.map(_ transform { + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + } + /** * Returns an UnsafeProjection for given StructType. * @@ -129,10 +154,7 @@ trait UnsafeProjectionCreator { * Returns an UnsafeProjection for given sequence of bound Expressions. */ def create(exprs: Seq[Expression]): UnsafeProjection = { - val unsafeExprs = exprs.map(_ transform { - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - createProjection(unsafeExprs) + createObject(toUnsafeExprs(exprs)) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) @@ -142,34 +164,24 @@ trait UnsafeProjectionCreator { * `inputSchema`. */ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { - create(exprs.map(BindReferences.bindReference(_, inputSchema))) - } - - /** - * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. - */ - protected def createProjection(exprs: Seq[Expression]): UnsafeProjection -} - -object UnsafeProjection extends UnsafeProjectionCreator { - - override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { - GenerateUnsafeProjection.generate(exprs) + create(toBoundExprs(exprs, inputSchema)) } /** * Same as other create()'s but allowing enabling/disabling subexpression elimination. - * TODO: refactor the plumbing and clean this up. + * The param `subexpressionEliminationEnabled` doesn't guarantee to work. For example, + * when fallbacking to interpreted execution, it is not supported. */ def create( exprs: Seq[Expression], inputSchema: Seq[Attribute], subexpressionEliminationEnabled: Boolean): UnsafeProjection = { - val e = exprs.map(BindReferences.bindReference(_, inputSchema)) - .map(_ transform { - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) + val unsafeExprs = toUnsafeExprs(toBoundExprs(exprs, inputSchema)) + try { + GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled) + } catch { + case CodegenError(_) => InterpretedUnsafeProjection.createProjection(unsafeExprs) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index efd664dde725a..6530b176968f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -34,10 +34,14 @@ object PythonUDF { e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType) } - def isGroupAggPandasUDF(e: Expression): Boolean = { + def isGroupedAggPandasUDF(e: Expression): Boolean = { e.isInstanceOf[PythonUDF] && e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF } + + // This is currently same as GroupedAggPandasUDF, but we might support new types in the future, + // e.g, N -> N transform. + def isWindowPandasUDF(e: Expression): Boolean = isGroupedAggPandasUDF(e) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index e869258469a97..3e7ca88249737 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.DataType /** @@ -1030,7 +1031,7 @@ case class ScalaUDF( """.stripMargin ev.copy(code = - s""" + code""" |$evalCode |${initArgs.mkString("\n")} |$callFunc diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index ff7c98f714905..76a881146a146 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ abstract sealed class SortDirection { @@ -147,7 +149,41 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { (!child.isAscending && child.nullOrdering == NullsLast) } - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + private lazy val calcPrefix: Any => Long = child.child.dataType match { + case BooleanType => (raw) => + if (raw.asInstanceOf[Boolean]) 1 else 0 + case DateType | TimestampType | _: IntegralType => (raw) => + raw.asInstanceOf[java.lang.Number].longValue() + case FloatType | DoubleType => (raw) => { + val dVal = raw.asInstanceOf[java.lang.Number].doubleValue() + DoublePrefixComparator.computePrefix(dVal) + } + case StringType => (raw) => + StringPrefixComparator.computePrefix(raw.asInstanceOf[UTF8String]) + case BinaryType => (raw) => + BinaryPrefixComparator.computePrefix(raw.asInstanceOf[Array[Byte]]) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + _.asInstanceOf[Decimal].toUnscaledLong + case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => + val p = Decimal.MAX_LONG_DIGITS + val s = p - (dt.precision - dt.scale) + (raw) => { + val value = raw.asInstanceOf[Decimal] + if (value.changePrecision(p, s)) value.toUnscaledLong else Long.MinValue + } + case dt: DecimalType => (raw) => + DoublePrefixComparator.computePrefix(raw.asInstanceOf[Decimal].toDouble) + case _ => (Any) => 0L + } + + override def eval(input: InternalRow): Any = { + val value = child.child.eval(input) + if (value == null) { + null + } else { + calcPrefix(value) + } + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childCode = child.child.genCode(ctx) @@ -181,7 +217,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { } ev.copy(code = childCode.code + - s""" + code""" |long ${ev.value} = 0L; |boolean ${ev.isNull} = ${childCode.isNull}; |if (!${childCode.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 787bcaf5e81de..9856b37e53fbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -46,7 +47,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { val idTerm = "partitionId" ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", + ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 6c4a3601c1730..84e38a8b2711e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -164,7 +165,7 @@ case class PreciseTimestampConversion( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) ev.copy(code = eval.code + - s"""boolean ${ev.isNull} = ${eval.isNull}; + code"""boolean ${ev.isNull} = ${eval.isNull}; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index a45854a3b5146..f1bbbdabb41f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -206,27 +206,15 @@ object ApproximatePercentile { * with limited memory. PercentileDigest is backed by [[QuantileSummaries]]. * * @param summaries underlying probabilistic data structure [[QuantileSummaries]]. - * @param isCompressed An internal flag from class [[QuantileSummaries]] to indicate whether the - * underlying quantileSummaries is compressed. */ - class PercentileDigest( - private var summaries: QuantileSummaries, - private var isCompressed: Boolean) { - - // Trigger compression if the QuantileSummaries's buffer length exceeds - // compressThresHoldBufferLength. The buffer length can be get by - // quantileSummaries.sampled.length - private[this] final val compressThresHoldBufferLength: Int = { - // Max buffer length after compression. - val maxBufferLengthAfterCompression: Int = (1 / summaries.relativeError).toInt * 2 - // A safe upper bound for buffer length before compression - maxBufferLengthAfterCompression * 2 - } + class PercentileDigest(private var summaries: QuantileSummaries) { def this(relativeError: Double) = { - this(new QuantileSummaries(defaultCompressThreshold, relativeError), isCompressed = true) + this(new QuantileSummaries(defaultCompressThreshold, relativeError, compressed = true)) } + private[sql] def isCompressed: Boolean = summaries.compressed + /** Returns compressed object of [[QuantileSummaries]] */ def quantileSummaries: QuantileSummaries = { if (!isCompressed) compress() @@ -236,14 +224,6 @@ object ApproximatePercentile { /** Insert an observation value into the PercentileDigest data structure. */ def add(value: Double): Unit = { summaries = summaries.insert(value) - // The result of QuantileSummaries.insert is un-compressed - isCompressed = false - - // Currently, QuantileSummaries ignores the construction parameter compressThresHold, - // which may cause QuantileSummaries to occupy unbounded memory. We have to hack around here - // to make sure QuantileSummaries doesn't occupy infinite memory. - // TODO: Figure out why QuantileSummaries ignores construction parameter compressThresHold - if (summaries.sampled.length >= compressThresHoldBufferLength) compress() } /** In-place merges in another PercentileDigest. */ @@ -280,7 +260,6 @@ object ApproximatePercentile { private final def compress(): Unit = { summaries = summaries.compress() - isCompressed = true } } @@ -335,8 +314,8 @@ object ApproximatePercentile { sampled(i) = Stats(value, g, delta) i += 1 } - val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count) - new PercentileDigest(summary, isCompressed = true) + val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count, true) + new PercentileDigest(summary) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 708bdbfc36058..a133bc2361eb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -23,24 +23,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.") -case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { - - override def prettyName: String = "avg" - - override def children: Seq[Expression] = child :: Nil +abstract class AverageLike(child: Expression) extends DeclarativeAggregate { override def nullable: Boolean = true - // Return data type. override def dataType: DataType = resultType - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function average") - private lazy val resultType = child.dataType match { case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 4, s + 4) @@ -62,14 +50,6 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit /* count = */ Literal(0L) ) - override lazy val updateExpressions = Seq( - /* sum = */ - Add( - sum, - Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), - /* count = */ If(IsNull(child), count, count + 1L) - ) - override lazy val mergeExpressions = Seq( /* sum = */ sum.left + sum.right, /* count = */ count.left + count.right @@ -85,4 +65,29 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case _ => Cast(sum, resultType) / Cast(count, resultType) } + + protected def updateExpressionsDef: Seq[Expression] = Seq( + /* sum = */ + Add( + sum, + Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), + /* count = */ If(IsNull(child), count, count + 1L) + ) + + override lazy val updateExpressions = updateExpressionsDef +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.") +case class Average(child: Expression) + extends AverageLike(child) with ImplicitCastInputTypes { + + override def prettyName: String = "avg" + + override def children: Seq[Expression] = child :: Nil + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 572d29caf5bc9..6bbb083f1e18e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -67,35 +67,7 @@ abstract class CentralMomentAgg(child: Expression) override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0)) - override val updateExpressions: Seq[Expression] = { - val newN = n + Literal(1.0) - val delta = child - avg - val deltaN = delta / newN - val newAvg = avg + deltaN - val newM2 = m2 + delta * (delta - deltaN) - - val delta2 = delta * delta - val deltaN2 = deltaN * deltaN - val newM3 = if (momentOrder >= 3) { - m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) - } else { - Literal(0.0) - } - val newM4 = if (momentOrder >= 4) { - m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + - delta * (delta * delta2 - deltaN * deltaN2) - } else { - Literal(0.0) - } - - trimHigherOrder(Seq( - If(IsNull(child), n, newN), - If(IsNull(child), avg, newAvg), - If(IsNull(child), m2, newM2), - If(IsNull(child), m3, newM3), - If(IsNull(child), m4, newM4) - )) - } + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef override val mergeExpressions: Seq[Expression] = { @@ -128,6 +100,36 @@ abstract class CentralMomentAgg(child: Expression) trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4)) } + + protected def updateExpressionsDef: Seq[Expression] = { + val newN = n + Literal(1.0) + val delta = child - avg + val deltaN = delta / newN + val newAvg = avg + deltaN + val newM2 = m2 + delta * (delta - deltaN) + + val delta2 = delta * delta + val deltaN2 = deltaN * deltaN + val newM3 = if (momentOrder >= 3) { + m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) + } else { + Literal(0.0) + } + val newM4 = if (momentOrder >= 4) { + m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + + delta * (delta * delta2 - deltaN * deltaN2) + } else { + Literal(0.0) + } + + trimHigherOrder(Seq( + If(IsNull(child), n, newN), + If(IsNull(child), avg, newAvg), + If(IsNull(child), m2, newM2), + If(IsNull(child), m3, newM3), + If(IsNull(child), m4, newM4) + )) + } } // Compute the population standard deviation of a column diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 95a4a0d5af634..3cdef72c1f2c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -22,17 +22,13 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ /** - * Compute Pearson correlation between two expressions. + * Base class for computing Pearson correlation between two expressions. * When applied on empty data (i.e., count is zero), it returns NULL. * * Definition of Pearson correlation can be found at * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.") -// scalastyle:on line.size.limit -case class Corr(x: Expression, y: Expression) +abstract class PearsonCorrelation(x: Expression, y: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { override def children: Seq[Expression] = Seq(x, y) @@ -51,7 +47,26 @@ case class Corr(x: Expression, y: Expression) override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0)) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef + + override val mergeExpressions: Seq[Expression] = { + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val dx = xAvg.right - xAvg.left + val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dy = yAvg.right - yAvg.left + val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val newXAvg = xAvg.left + dxN * n2 + val newYAvg = yAvg.left + dyN * n2 + val newCk = ck.left + ck.right + dx * dyN * n1 * n2 + val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 + val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 + + Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) + } + + protected def updateExpressionsDef: Seq[Expression] = { val newN = n + Literal(1.0) val dx = x - xAvg val dxN = dx / newN @@ -73,24 +88,15 @@ case class Corr(x: Expression, y: Expression) If(isNull, yMk, newYMk) ) } +} - override val mergeExpressions: Seq[Expression] = { - - val n1 = n.left - val n2 = n.right - val newN = n1 + n2 - val dx = xAvg.right - xAvg.left - val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) - val dy = yAvg.right - yAvg.left - val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) - val newXAvg = xAvg.left + dxN * n2 - val newYAvg = yAvg.left + dyN * n2 - val newCk = ck.left + ck.right + dx * dyN * n1 * n2 - val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 - val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 - Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) - } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.") +// scalastyle:on line.size.limit +case class Corr(x: Expression, y: Expression) + extends PearsonCorrelation(x, y) { override val evaluateExpression: Expression = { If(n === Literal(0.0), Literal.create(null, DoubleType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 1990f2f2f0722..40582d0abd762 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -21,24 +21,16 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = """ - _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null. - - _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null. - - _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null. - """) -// scalastyle:on line.size.limit -case class Count(children: Seq[Expression]) extends DeclarativeAggregate { - +/** + * Base class for all counting aggregators. + */ +abstract class CountLike extends DeclarativeAggregate { override def nullable: Boolean = false // Return data type. override def dataType: DataType = LongType - private lazy val count = AttributeReference("count", LongType, nullable = false)() + protected lazy val count = AttributeReference("count", LongType, nullable = false)() override lazy val aggBufferAttributes = count :: Nil @@ -46,6 +38,27 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { /* count = */ Literal(0L) ) + override lazy val mergeExpressions = Seq( + /* count = */ count.left + count.right + ) + + override lazy val evaluateExpression = count + + override def defaultResult: Option[Literal] = Option(Literal(0L)) +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null. + + _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null. + + _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null. + """) +// scalastyle:on line.size.limit +case class Count(children: Seq[Expression]) extends CountLike { + override lazy val updateExpressions = { val nullableChildren = children.filter(_.nullable) if (nullableChildren.isEmpty) { @@ -58,14 +71,6 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { ) } } - - override lazy val mergeExpressions = Seq( - /* count = */ count.left + count.right - ) - - override lazy val evaluateExpression = count - - override def defaultResult: Option[Literal] = Option(Literal(0L)) } object Count { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index fc6c34baafdd1..72a7c62b328ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -42,23 +42,7 @@ abstract class Covariance(x: Expression, y: Expression) override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0)) - override lazy val updateExpressions: Seq[Expression] = { - val newN = n + Literal(1.0) - val dx = x - xAvg - val dy = y - yAvg - val dyN = dy / newN - val newXAvg = xAvg + dx / newN - val newYAvg = yAvg + dyN - val newCk = ck + dx * (y - newYAvg) - - val isNull = IsNull(x) || IsNull(y) - Seq( - If(isNull, n, newN), - If(isNull, xAvg, newXAvg), - If(isNull, yAvg, newYAvg), - If(isNull, ck, newCk) - ) - } + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef override val mergeExpressions: Seq[Expression] = { @@ -75,6 +59,24 @@ abstract class Covariance(x: Expression, y: Expression) Seq(newN, newXAvg, newYAvg, newCk) } + + protected def updateExpressionsDef: Seq[Expression] = { + val newN = n + Literal(1.0) + val dx = x - xAvg + val dy = y - yAvg + val dyN = dy / newN + val newXAvg = xAvg + dx / newN + val newYAvg = yAvg + dyN + val newCk = ck + dx * (y - newYAvg) + + val isNull = IsNull(x) || IsNull(y) + Seq( + If(isNull, n, newN), + If(isNull, xAvg, newXAvg), + If(isNull, yAvg, newYAvg), + If(isNull, ck, newCk) + ) + } } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala new file mode 100644 index 0000000000000..d8f4505588ff2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{AbstractDataType, DoubleType} + +/** + * Base trait for all regression functions. + */ +trait RegrLike extends AggregateFunction with ImplicitCastInputTypes { + def y: Expression + def x: Expression + + override def children: Seq[Expression] = Seq(y, x) + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + + protected def updateIfNotNull(exprs: Seq[Expression]): Seq[Expression] = { + assert(aggBufferAttributes.length == exprs.length) + val nullableChildren = children.filter(_.nullable) + if (nullableChildren.isEmpty) { + exprs + } else { + exprs.zip(aggBufferAttributes).map { case (e, a) => + If(nullableChildren.map(IsNull).reduce(Or), a, e) + } + } + } +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the number of non-null pairs.", + since = "2.4.0") +case class RegrCount(y: Expression, x: Expression) + extends CountLike with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(Seq(count + 1L)) + + override def prettyName: String = "regr_count" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrSXX(y: Expression, x: Expression) + extends CentralMomentAgg(x) with RegrLike { + + override protected def momentOrder = 2 + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), m2) + } + + override def prettyName: String = "regr_sxx" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrSYY(y: Expression, x: Expression) + extends CentralMomentAgg(y) with RegrLike { + + override protected def momentOrder = 2 + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), m2) + } + + override def prettyName: String = "regr_syy" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the average of x. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrAvgX(y: Expression, x: Expression) + extends AverageLike(x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override def prettyName: String = "regr_avgx" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the average of y. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrAvgY(y: Expression, x: Expression) + extends AverageLike(y) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override def prettyName: String = "regr_avgy" +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the covariance of y and x multiplied for the number of items in the dataset. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrSXY(y: Expression, x: Expression) + extends Covariance(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), ck) + } + + override def prettyName: String = "regr_sxy" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the slope of the linear regression line. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrSlope(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), ck / yMk) + } + + override def prettyName: String = "regr_slope" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the coefficient of determination (also called R-squared or goodness of fit) for the regression line. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrR2(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), + If(xMk === Literal(0.0), Literal(1.0), ck * ck / yMk / xMk)) + } + + override def prettyName: String = "regr_r2" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the y-intercept of the linear regression line. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrIntercept(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), + xAvg - (ck / yMk) * yAvg) + } + + override def prettyName: String = "regr_intercept" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d4e322d23b95b..fe91e520169b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -220,30 +221,12 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.", - examples = """ - Examples: - > SELECT 3 _FUNC_ 2; - 1.5 - > SELECT 2L _FUNC_ 2L; - 1.0 - """) -// scalastyle:on line.size.limit -case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { - - override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) +// Common base trait for Divide and Remainder, since these two classes are almost identical +trait DivModLike extends BinaryArithmetic { - override def symbol: String = "/" - override def decimalMethod: String = "$div" override def nullable: Boolean = true - private lazy val div: (Any, Any) => Any = dataType match { - case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div - } - - override def eval(input: InternalRow): Any = { + final override def eval(input: InternalRow): Any = { val input2 = right.eval(input) if (input2 == null || input2 == 0) { null @@ -252,13 +235,15 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (input1 == null) { null } else { - div(input1, input2) + evalOperation(input1, input2) } } } + def evalOperation(left: Any, right: Any): Any + /** - * Special case handling due to division by 0 => null. + * Special case handling due to division/remainder by 0 => null. */ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval1 = left.genCode(ctx) @@ -269,13 +254,13 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic s"${eval2.value} == 0" } val javaType = CodeGenerator.javaType(dataType) - val divide = if (dataType.isInstanceOf[DecimalType]) { + val operation = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -283,10 +268,10 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ${ev.isNull} = true; } else { ${eval1.code} - ${ev.value} = $divide; + ${ev.value} = $operation; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -297,13 +282,38 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (${eval1.isNull}) { ${ev.isNull} = true; } else { - ${ev.value} = $divide; + ${ev.value} = $operation; } }""") } } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.", + examples = """ + Examples: + > SELECT 3 _FUNC_ 2; + 1.5 + > SELECT 2L _FUNC_ 2L; + 1.0 + """) +// scalastyle:on line.size.limit +case class Divide(left: Expression, right: Expression) extends DivModLike { + + override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) + + override def symbol: String = "/" + override def decimalMethod: String = "$div" + + private lazy val div: (Any, Any) => Any = dataType match { + case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div + } + + override def evalOperation(left: Any, right: Any): Any = div(left, right) +} + @ExpressionDescription( usage = "expr1 _FUNC_ expr2 - Returns the remainder after `expr1`/`expr2`.", examples = """ @@ -313,82 +323,30 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic > SELECT MOD(2, 1.8); 0.2 """) -case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { +case class Remainder(left: Expression, right: Expression) extends DivModLike { override def inputType: AbstractDataType = NumericType override def symbol: String = "%" override def decimalMethod: String = "remainder" - override def nullable: Boolean = true - private lazy val integral = dataType match { - case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] - case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] + private lazy val mod: (Any, Any) => Any = dataType match { + // special cases to make float/double primitive types faster + case DoubleType => + (left, right) => left.asInstanceOf[Double] % right.asInstanceOf[Double] + case FloatType => + (left, right) => left.asInstanceOf[Float] % right.asInstanceOf[Float] + + // catch-all cases + case i: IntegralType => + val integral = i.integral.asInstanceOf[Integral[Any]] + (left, right) => integral.rem(left, right) + case i: FractionalType => // should only be DecimalType for now + val integral = i.asIntegral.asInstanceOf[Integral[Any]] + (left, right) => integral.rem(left, right) } - override def eval(input: InternalRow): Any = { - val input2 = right.eval(input) - if (input2 == null || input2 == 0) { - null - } else { - val input1 = left.eval(input) - if (input1 == null) { - null - } else { - input1 match { - case d: Double => d % input2.asInstanceOf[java.lang.Double] - case f: Float => f % input2.asInstanceOf[java.lang.Float] - case _ => integral.rem(input1, input2) - } - } - } - } - - /** - * Special case handling for x % 0 ==> null. - */ - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval1 = left.genCode(ctx) - val eval2 = right.genCode(ctx) - val isZero = if (dataType.isInstanceOf[DecimalType]) { - s"${eval2.value}.isZero()" - } else { - s"${eval2.value} == 0" - } - val javaType = CodeGenerator.javaType(dataType) - val remainder = if (dataType.isInstanceOf[DecimalType]) { - s"${eval1.value}.$decimalMethod(${eval2.value})" - } else { - s"($javaType)(${eval1.value} $symbol ${eval2.value})" - } - if (!left.nullable && !right.nullable) { - ev.copy(code = s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if ($isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - ${ev.value} = $remainder; - }""") - } else { - ev.copy(code = s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if (${eval2.isNull} || $isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - if (${eval1.isNull}) { - ${ev.isNull} = true; - } else { - ${ev.value} = $remainder; - } - }""") - } - } + override def evalOperation(left: Any, right: Any): Any = mod(left, right) } @ExpressionDescription( @@ -479,7 +437,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -490,7 +448,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { $result }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -612,7 +570,7 @@ case class Least(children: Seq[Expression]) extends Expression { """.stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes @@ -687,7 +645,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { """.stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cf0a91ff00626..838c045d5bcce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -38,10 +38,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types._ import org.apache.spark.util.{ParentClassLoader, Utils} @@ -56,19 +58,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) +case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue) object ExprCode { def apply(isNull: ExprValue, value: ExprValue): ExprCode = { - ExprCode(code = "", isNull, value) + ExprCode(code = EmptyBlock, isNull, value) } def forNullValue(dataType: DataType): ExprCode = { - ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) + ExprCode(code = EmptyBlock, isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) } def forNonNullValue(value: ExprValue): ExprCode = { - ExprCode(code = "", isNull = FalseLiteral, value = value) + ExprCode(code = EmptyBlock, isNull = FalseLiteral, value = value) } } @@ -329,9 +331,9 @@ class CodegenContext { def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { val value = addMutableState(javaType(dataType), variableName) val code = dataType match { - case StringType => s"$value = $initCode.clone();" - case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" - case _ => s"$value = $initCode;" + case StringType => code"$value = $initCode.clone();" + case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();" + case _ => code"$value = $initCode;" } ExprCode(code, FalseLiteral, JavaCode.global(value, dataType)) } @@ -730,6 +732,73 @@ class CodegenContext { """.stripMargin } + /** + * Generates code creating a [[UnsafeArrayData]]. + * + * @param arrayName name of the array to create + * @param numElements code representing the number of elements the array should contain + * @param elementType data type of the elements in the array + * @param additionalErrorMessage string to include in the error message + */ + def createUnsafeArray( + arrayName: String, + numElements: String, + elementType: DataType, + additionalErrorMessage: String): String = { + val arraySize = freshName("size") + val arrayBytes = freshName("arrayBytes") + + s""" + |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElements, + | ${elementType.defaultSize}); + |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful try create array with " + $arraySize + + | " bytes of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." + + | "$additionalErrorMessage"); + |} + |byte[] $arrayBytes = new byte[(int)$arraySize]; + |UnsafeArrayData $arrayName = new UnsafeArrayData(); + |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); + |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); + """.stripMargin + } + + /** + * Generates code creating a [[UnsafeArrayData]]. The generated code executes + * a provided fallback when the size of backing array would exceed the array size limit. + * @param arrayName a name of the array to create + * @param numElements a piece of code representing the number of elements the array should contain + * @param elementSize a size of an element in bytes + * @param bodyCode a function generating code that fills up the [[UnsafeArrayData]] + * and getting the backing array as a parameter + * @param fallbackCode a piece of code executed when the array size limit is exceeded + */ + def createUnsafeArrayWithFallback( + arrayName: String, + numElements: String, + elementSize: Int, + bodyCode: String => String, + fallbackCode: String): String = { + val arraySize = freshName("size") + val arrayBytes = freshName("arrayBytes") + s""" + |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElements, + | $elementSize); + |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | $fallbackCode + |} else { + | final byte[] $arrayBytes = new byte[(int)$arraySize]; + | UnsafeArrayData $arrayName = new UnsafeArrayData(); + | Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); + | $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); + | ${bodyCode(arrayBytes)} + |} + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. @@ -750,6 +819,36 @@ class CodegenContext { } } + /** + * Generates code to do null safe execution when accessing properties of complex + * ArrayData elements. + * + * @param nullElements used to decide whether the ArrayData might contain null or not. + * @param isNull a variable indicating whether the result will be evaluated to null or not. + * @param arrayData a variable name representing the ArrayData. + * @param execute the code that should be executed only if the ArrayData doesn't contain + * any null. + */ + def nullArrayElementsSaveExec( + nullElements: Boolean, + isNull: String, + arrayData: String)( + execute: String): String = { + val i = freshName("idx") + if (nullElements) { + s""" + |for (int $i = 0; !$isNull && $i < $arrayData.numElements(); $i++) { + | $isNull |= $arrayData.isNullAt($i); + |} + |if (!$isNull) { + | $execute + |} + """.stripMargin + } else { + execute + } + } + /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow @@ -988,7 +1087,7 @@ class CodegenContext { val eval = expr.genCode(this) val state = SubExprEliminationState(eval.isNull, eval.value) e.foreach(localSubExprEliminationExprs.put(_, state)) - eval.code.trim + eval.code.toString } SubExprCodes(codes, localSubExprEliminationExprs.toMap) } @@ -1016,7 +1115,7 @@ class CodegenContext { val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code.trim} + | ${eval.code} | $isNull = ${eval.isNull}; | $value = ${eval.value}; |} @@ -1073,7 +1172,7 @@ class CodegenContext { def registerComment( text: => String, placeholderId: String = "", - force: Boolean = false): String = { + force: Boolean = false): Block = { // By default, disable comments in generated code because computing the comments themselves can // be extremely expensive in certain cases, such as deeply-nested expressions which operate over // inputs with wide schemas. For more details on the performance issues that motivated this @@ -1092,9 +1191,9 @@ class CodegenContext { s"// $text" } placeHolderToComments += (name -> comment) - s"/*$name*/" + code"/*$name*/" } else { - "" + EmptyBlock } } } @@ -1316,7 +1415,7 @@ object CodeGenerator extends Logging { * weak keys/values and thus does not respond to memory pressure. */ private val cache = CacheBuilder.newBuilder() - .maximumSize(100) + .maximumSize(SQLConf.get.codegenCacheMaxEntries) .build( new CacheLoader[CodeAndComment, (GeneratedClass, Int)]() { override def load(code: CodeAndComment): (GeneratedClass, Int) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index a91989e129664..3f4704d287cbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ /** * A trait that can be used to provide a fallback mode for expression code generation. @@ -46,7 +47,7 @@ trait CodegenFallback extends Expression { val placeHolder = ctx.registerComment(this.toString) val javaType = CodeGenerator.javaType(this.dataType) if (nullable) { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; @@ -55,7 +56,7 @@ trait CodegenFallback extends Expression { ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 01c350e9dbf69..39778661d1c48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -22,6 +22,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -71,7 +72,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values) ) val code = - s""" + code""" |final InternalRow $tmpInput = $input; |final Object[] $values = new Object[${schema.length}]; |$allFields @@ -97,7 +98,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ctx, JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType), elementType) - val code = s""" + val code = code""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); final Object[] $values = new Object[$numElements]; @@ -124,7 +125,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType) val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType) - val code = s""" + val code = code""" final MapData $tmpInput = $input; ${keyConverter.code} ${valueConverter.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 01b4d6c4529bd..8f2a5a0dce943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -286,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = - s""" + code""" |$rowWriter.reset(); |$evalSubexpr |$writeExpressions @@ -343,7 +344,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | } | | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | return ${eval.value}; | } | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 74ff018488863..2f8c853e836ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql.catalyst.expressions.codegen import java.lang.{Boolean => JBool} +import scala.collection.mutable.ArrayBuffer import scala.language.{existentials, implicitConversions} +import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{BooleanType, DataType} /** @@ -114,6 +116,158 @@ object JavaCode { } } +/** + * A trait representing a block of java code. + */ +trait Block extends TreeNode[Block] with JavaCode { + import Block._ + + // Returns java code string for this code block. + override def toString: String = _marginChar match { + case Some(c) => code.stripMargin(c).trim + case _ => code.trim + } + + def length: Int = toString.length + + def nonEmpty: Boolean = toString.nonEmpty + + // The leading prefix that should be stripped from each line. + // By default we strip blanks or control characters followed by '|' from the line. + var _marginChar: Option[Char] = Some('|') + + def stripMargin(c: Char): this.type = { + _marginChar = Some(c) + this + } + + def stripMargin: this.type = { + _marginChar = Some('|') + this + } + + /** + * Apply a map function to each java expression codes present in this java code, and return a new + * java code based on the mapped java expression codes. + */ + def transformExprValues(f: PartialFunction[ExprValue, ExprValue]): this.type = { + var changed = false + + @inline def transform(e: ExprValue): ExprValue = { + val newE = f lift e + if (!newE.isDefined || newE.get.equals(e)) { + e + } else { + changed = true + newE.get + } + } + + def doTransform(arg: Any): AnyRef = arg match { + case e: ExprValue => transform(e) + case Some(value) => Some(doTransform(value)) + case seq: Traversable[_] => seq.map(doTransform) + case other: AnyRef => other + } + + val newArgs = mapProductIterator(doTransform) + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this + } + + // Concatenates this block with other block. + def + (other: Block): Block = other match { + case EmptyBlock => this + case _ => code"$this\n$other" + } + + override def verboseString: String = toString +} + +object Block { + + val CODE_BLOCK_BUFFER_LENGTH: Int = 512 + + implicit def blocksToBlock(blocks: Seq[Block]): Block = blocks.reduceLeft(_ + _) + + implicit class BlockHelper(val sc: StringContext) extends AnyVal { + def code(args: Any*): Block = { + sc.checkLengths(args) + if (sc.parts.length == 0) { + EmptyBlock + } else { + args.foreach { + case _: ExprValue => + case _: Int | _: Long | _: Float | _: Double | _: String => + case _: Block => + case other => throw new IllegalArgumentException( + s"Can not interpolate ${other.getClass.getName} into code block.") + } + + val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args) + CodeBlock(codeParts, blockInputs) + } + } + } + + // Folds eagerly the literal args into the code parts. + private def foldLiteralArgs(parts: Seq[String], args: Seq[Any]): (Seq[String], Seq[JavaCode]) = { + val codeParts = ArrayBuffer.empty[String] + val blockInputs = ArrayBuffer.empty[JavaCode] + + val strings = parts.iterator + val inputs = args.iterator + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + + buf.append(strings.next) + while (strings.hasNext) { + val input = inputs.next + input match { + case _: ExprValue | _: CodeBlock => + codeParts += buf.toString + buf.clear + blockInputs += input.asInstanceOf[JavaCode] + case EmptyBlock => + case _ => + buf.append(input) + } + buf.append(strings.next) + } + codeParts += buf.toString + + (codeParts.toSeq, blockInputs.toSeq) + } +} + +/** + * A block of java code. Including a sequence of code parts and some inputs to this block. + * The actual java code is generated by embedding the inputs into the code parts. Here we keep + * inputs of `JavaCode` instead of simply folding them as a string of code, because we need to + * track expressions (`ExprValue`) in this code block. We need to be able to manipulate the + * expressions later without changing the behavior of this code block in some applications, e.g., + * method splitting. + */ +case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block { + override def children: Seq[Block] = + blockInputs.filter(_.isInstanceOf[Block]).asInstanceOf[Seq[Block]] + + override lazy val code: String = { + val strings = codeParts.iterator + val inputs = blockInputs.iterator + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + buf.append(StringContext.treatEscapes(strings.next)) + while (strings.hasNext) { + buf.append(inputs.next) + buf.append(StringContext.treatEscapes(strings.next)) + } + buf.toString + } +} + +case object EmptyBlock extends Block with Serializable { + override val code: String = "" + override def children: Seq[Block] = Seq.empty +} + /** * A typed java fragment that must be a valid java expression. */ @@ -123,10 +277,9 @@ trait ExprValue extends JavaCode { } object ExprValue { - implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.code } - /** * A java expression fragment. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c16793bda028e..0f4f4f1601b4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -16,49 +16,110 @@ */ package org.apache.spark.sql.catalyst.expressions -import java.util.Comparator +import java.util.{Comparator, TimeZone} + +import scala.collection.mutable +import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.collection.OpenHashSet + +/** + * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit + * casting. + */ +trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression + with ImplicitCastInputTypes { + + @transient protected lazy val elementType: DataType = + inputTypes.head.asInstanceOf[ArrayType].elementType + + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, hasNull1), ArrayType(e2, hasNull2)) => + TypeCoercion.findTightestCommonType(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull1), ArrayType(dt, hasNull2)) + case _ => Seq.empty + } + case _ => Seq.empty + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) => + TypeCheckResult.TypeCheckSuccess + case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " + + s"been two ${ArrayType.simpleString}s with same element type, but it's " + + s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]") + } + } +} + /** - * Given an array or map, returns its size. Returns -1 if null. + * Given an array or map, returns total number of elements in it. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the size of an array or a map. Returns -1 if null.", + usage = """ + _FUNC_(expr) - Returns the size of an array or a map. + The function returns -1 if its input is null and spark.sql.legacy.sizeOfNull is set to true. + If spark.sql.legacy.sizeOfNull is set to false, the function returns null for null input. + By default, the spark.sql.legacy.sizeOfNull parameter is set to true. + """, examples = """ Examples: > SELECT _FUNC_(array('b', 'd', 'c', 'a')); 4 + > SELECT _FUNC_(map('a', 1, 'b', 2)); + 2 + > SELECT _FUNC_(NULL); + -1 """) case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + val legacySizeOfNull = SQLConf.get.legacySizeOfNull + override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) - override def nullable: Boolean = false + override def nullable: Boolean = if (legacySizeOfNull) false else super.nullable override def eval(input: InternalRow): Any = { val value = child.eval(input) if (value == null) { - -1 + if (legacySizeOfNull) -1 else null } else child.dataType match { case _: ArrayType => value.asInstanceOf[ArrayData].numElements() case _: MapType => value.asInstanceOf[MapData].numElements() + case other => throw new UnsupportedOperationException( + s"The size function doesn't support the operand type ${other.getClass.getCanonicalName}") } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val childGen = child.genCode(ctx) - ev.copy(code = s""" + if (legacySizeOfNull) { + val childGen = child.genCode(ctx) + ev.copy(code = code""" boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : (${childGen.value}).numElements();""", isNull = FalseLiteral) + } else { + defineCodeGen(ctx, ev, c => s"($c).numElements()") + } } } @@ -90,6 +151,172 @@ case class MapKeys(child: Expression) override def prettyName: String = "map_keys" } +@ExpressionDescription( + usage = """ + _FUNC_(a1, a2, ...) - Returns a merged array of structs in which the N-th struct contains all + N-th values of input arrays. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4)); + [[1, 2], [2, 3], [3, 4]] + > SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4)); + [[1, 2, 3], [2, 3, 4]] + """, + since = "2.4.0") +case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType) + + override def dataType: DataType = ArrayType(mountSchema) + + override def nullable: Boolean = children.exists(_.nullable) + + private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType]) + + private lazy val arrayElementTypes = arrayTypes.map(_.elementType) + + @transient private lazy val mountSchema: StructType = { + val fields = children.zip(arrayElementTypes).zipWithIndex.map { + case ((expr: NamedExpression, elementType), _) => + StructField(expr.name, elementType, nullable = true) + case ((_, elementType), idx) => + StructField(idx.toString, elementType, nullable = true) + } + StructType(fields) + } + + @transient lazy val numberOfArrays: Int = children.length + + @transient lazy val genericArrayData = classOf[GenericArrayData].getName + + def emptyInputGenCode(ev: ExprCode): ExprCode = { + ev.copy(code""" + |${CodeGenerator.javaType(dataType)} ${ev.value} = new $genericArrayData(new Object[0]); + |boolean ${ev.isNull} = false; + """.stripMargin) + } + + def nonEmptyInputGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val genericInternalRow = classOf[GenericInternalRow].getName + val arrVals = ctx.freshName("arrVals") + val biggestCardinality = ctx.freshName("biggestCardinality") + + val currentRow = ctx.freshName("currentRow") + val j = ctx.freshName("j") + val i = ctx.freshName("i") + val args = ctx.freshName("args") + + val evals = children.map(_.genCode(ctx)) + val getValuesAndCardinalities = evals.zipWithIndex.map { case (eval, index) => + s""" + |if ($biggestCardinality != -1) { + | ${eval.code} + | if (!${eval.isNull}) { + | $arrVals[$index] = ${eval.value}; + | $biggestCardinality = Math.max($biggestCardinality, ${eval.value}.numElements()); + | } else { + | $biggestCardinality = -1; + | } + |} + """.stripMargin + } + + val splittedGetValuesAndCardinalities = ctx.splitExpressionsWithCurrentInputs( + expressions = getValuesAndCardinalities, + funcName = "getValuesAndCardinalities", + returnType = "int", + makeSplitFunction = body => + s""" + |$body + |return $biggestCardinality; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"), + extraArguments = + ("ArrayData[]", arrVals) :: + ("int", biggestCardinality) :: Nil) + + val getValueForType = arrayElementTypes.zipWithIndex.map { case (eleType, idx) => + val g = CodeGenerator.getValue(s"$arrVals[$idx]", eleType, i) + s""" + |if ($i < $arrVals[$idx].numElements() && !$arrVals[$idx].isNullAt($i)) { + | $currentRow[$idx] = $g; + |} else { + | $currentRow[$idx] = null; + |} + """.stripMargin + } + + val getValueForTypeSplitted = ctx.splitExpressions( + expressions = getValueForType, + funcName = "extractValue", + arguments = + ("int", i) :: + ("Object[]", currentRow) :: + ("ArrayData[]", arrVals) :: Nil) + + val initVariables = s""" + |ArrayData[] $arrVals = new ArrayData[$numberOfArrays]; + |int $biggestCardinality = 0; + |${CodeGenerator.javaType(dataType)} ${ev.value} = null; + """.stripMargin + + ev.copy(code""" + |$initVariables + |$splittedGetValuesAndCardinalities + |boolean ${ev.isNull} = $biggestCardinality == -1; + |if (!${ev.isNull}) { + | Object[] $args = new Object[$biggestCardinality]; + | for (int $i = 0; $i < $biggestCardinality; $i ++) { + | Object[] $currentRow = new Object[$numberOfArrays]; + | $getValueForTypeSplitted + | $args[$i] = new $genericInternalRow($currentRow); + | } + | ${ev.value} = new $genericArrayData($args); + |} + """.stripMargin) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + if (numberOfArrays == 0) { + emptyInputGenCode(ev) + } else { + nonEmptyInputGenCode(ctx, ev) + } + } + + override def eval(input: InternalRow): Any = { + val inputArrays = children.map(_.eval(input).asInstanceOf[ArrayData]) + if (inputArrays.contains(null)) { + null + } else { + val biggestCardinality = if (inputArrays.isEmpty) { + 0 + } else { + inputArrays.map(_.numElements()).max + } + + val result = new Array[InternalRow](biggestCardinality) + val zippedArrs: Seq[(ArrayData, Int)] = inputArrays.zipWithIndex + + for (i <- 0 until biggestCardinality) { + val currentLayer: Seq[Object] = zippedArrs.map { case (arr, index) => + if (i < arr.numElements() && !arr.isNullAt(i)) { + arr.get(i, arrayElementTypes(index)) + } else { + null + } + } + + result(i) = InternalRow.apply(currentLayer: _*) + } + new GenericArrayData(result) + } + } + + override def prettyName: String = "arrays_zip" +} + /** * Returns an unordered array containing the values of the map. */ @@ -119,767 +346,3462 @@ case class MapValues(child: Expression) } /** - * Sorts the input array in ascending / descending order according to the natural ordering of - * the array elements and returns it. + * Returns an unordered array of all entries in the given map. */ -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order according to the natural ordering of the array elements.", + usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.", examples = """ Examples: - > SELECT _FUNC_(array('b', 'd', 'c', 'a'), true); - ["a","b","c","d"] - """) -// scalastyle:on line.size.limit -case class SortArray(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ExpectsInputTypes with CodegenFallback { + > SELECT _FUNC_(map(1, 'a', 2, 'b')); + [(1,"a"),(2,"b")] + """, + since = "2.4.0") +case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { - def this(e: Expression) = this(e, Literal(true)) + override def inputTypes: Seq[AbstractDataType] = Seq(MapType) - override def left: Expression = base - override def right: Expression = ascendingOrder - override def dataType: DataType = base.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) + lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] - override def checkInputDataTypes(): TypeCheckResult = base.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - ascendingOrder match { - case Literal(_: Boolean, BooleanType) => - TypeCheckResult.TypeCheckSuccess - case _ => - TypeCheckResult.TypeCheckFailure( - "Sort order in second argument requires a boolean literal.") - } - case ArrayType(dt, _) => - TypeCheckResult.TypeCheckFailure( - s"$prettyName does not support sorting array of type ${dt.simpleString}") - case _ => - TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + override def dataType: DataType = { + ArrayType( + StructType( + StructField("key", childDataType.keyType, false) :: + StructField("value", childDataType.valueType, childDataType.valueContainsNull) :: + Nil), + false) } - @transient - private lazy val lt: Comparator[Any] = { - val ordering = base.dataType match { - case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + override protected def nullSafeEval(input: Any): Any = { + val childMap = input.asInstanceOf[MapData] + val keys = childMap.keyArray() + val values = childMap.valueArray() + val length = childMap.numElements() + val resultData = new Array[AnyRef](length) + var i = 0; + while (i < length) { + val key = keys.get(i, childDataType.keyType) + val value = values.get(i, childDataType.valueType) + val row = new GenericInternalRow(Array[Any](key, value)) + resultData.update(i, row) + i += 1 } + new GenericArrayData(resultData) + } - new Comparator[Any]() { - override def compare(o1: Any, o2: Any): Int = { - if (o1 == null && o2 == null) { - 0 - } else if (o1 == null) { - -1 - } else if (o2 == null) { - 1 - } else { - ordering.compare(o1, o2) - } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numElements = ctx.freshName("numElements") + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType) + val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) + val code = if (isKeyPrimitive && isValuePrimitive) { + genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements) + } else { + genCodeForAnyElements(ctx, keys, values, ev.value, numElements) } - } + s""" + |final int $numElements = $c.numElements(); + |final ArrayData $keys = $c.keyArray(); + |final ArrayData $values = $c.valueArray(); + |$code + """.stripMargin + }) } - @transient - private lazy val gt: Comparator[Any] = { - val ordering = base.dataType match { - case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] - } + private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z") - new Comparator[Any]() { - override def compare(o1: Any, o2: Any): Int = { - if (o1 == null && o2 == null) { - 0 - } else if (o1 == null) { - 1 - } else if (o2 == null) { - -1 - } else { - -ordering.compare(o1, o2) - } - } + private def getValue(varName: String) = { + CodeGenerator.getValue(varName, childDataType.valueType, "z") + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + keys: String, + values: String, + arrayData: String, + numElements: String): String = { + val unsafeRow = ctx.freshName("unsafeRow") + val unsafeArrayData = ctx.freshName("unsafeArrayData") + val structsOffset = ctx.freshName("structsOffset") + val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val wordSize = UnsafeRow.WORD_SIZE + val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 + val structSizeAsLong = structSize + "L" + val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + + val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});" + val valueAssignmentChecked = if (childDataType.valueContainsNull) { + s""" + |if ($values.isNullAt(z)) { + | $unsafeRow.setNullAt(1); + |} else { + | $valueAssignment + |} + """.stripMargin + } else { + valueAssignment } + + val assignmentLoop = (byteArray: String) => + s""" + |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize; + |UnsafeRow $unsafeRow = new UnsafeRow(2); + |for (int z = 0; z < $numElements; z++) { + | long offset = $structsOffset + z * $structSizeAsLong; + | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong); + | $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize); + | $unsafeRow.set$keyTypeName(0, ${getKey(keys)}); + | $valueAssignmentChecked + |} + |$arrayData = $unsafeArrayData; + """.stripMargin + + ctx.createUnsafeArrayWithFallback( + unsafeArrayData, + numElements, + structSize + wordSize, + assignmentLoop, + genCodeForAnyElements(ctx, keys, values, arrayData, numElements)) } - override def nullSafeEval(array: Any, ascending: Any): Any = { - val elementType = base.dataType.asInstanceOf[ArrayType].elementType - val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) - if (elementType != NullType) { - java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt) + private def genCodeForAnyElements( + ctx: CodegenContext, + keys: String, + values: String, + arrayData: String, + numElements: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val rowClass = classOf[GenericInternalRow].getName + val data = ctx.freshName("internalRowArray") + + val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) + val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) { + s"$values.isNullAt(z) ? null : (Object)${getValue(values)}" + } else { + getValue(values) } - new GenericArrayData(data.asInstanceOf[Array[Any]]) + + s""" + |final Object[] $data = new Object[$numElements]; + |for (int z = 0; z < $numElements; z++) { + | $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck}); + |} + |$arrayData = new $genericArrayClass($data); + """.stripMargin } - override def prettyName: String = "sort_array" + override def prettyName: String = "map_entries" } /** - * Returns a reversed string or an array with reverse order of elements. + * Returns the union of all the given maps. */ @ExpressionDescription( - usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.", + usage = "_FUNC_(map, ...) - Returns the union of all the given maps", examples = """ Examples: - > SELECT _FUNC_('Spark SQL'); - LQS krapS - > SELECT _FUNC_(array(2, 1, 4, 3)); - [3, 4, 1, 2] - """, - since = "1.5.0", - note = "Reverse logic for arrays is available since 2.4.0." -) -case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); + [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]] + """, since = "2.4.0") +case class MapConcat(children: Seq[Expression]) extends Expression { - // Input types are utilized by type coercion in ImplicitTypeCasts. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) + override def checkInputDataTypes(): TypeCheckResult = { + var funcName = s"function $prettyName" + if (children.exists(!_.dataType.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckFailure( + s"input to $funcName should all be of type map, but it's " + + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) + } else { + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) + } + } - override def dataType: DataType = child.dataType + override def dataType: MapType = { + val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption + .getOrElse(MapType(StringType, StringType)) + val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType]) + .exists(_.valueContainsNull) + if (dt.valueContainsNull != valueContainsNull) { + dt.copy(valueContainsNull = valueContainsNull) + } else { + dt + } + } - lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + override def nullable: Boolean = children.exists(_.nullable) - override def nullSafeEval(input: Any): Any = input match { - case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) - case s: UTF8String => s.reverse() + override def eval(input: InternalRow): Any = { + val maps = children.map(_.eval(input)) + if (maps.contains(null)) { + return null + } + val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray()) + val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray()) + + val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " + + s"elements due to exceeding the map size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } + val finalKeyArray = new Array[AnyRef](numElements.toInt) + val finalValueArray = new Array[AnyRef](numElements.toInt) + var position = 0 + for (i <- keyArrayDatas.indices) { + val keyArray = keyArrayDatas(i).toObjectArray(dataType.keyType) + val valueArray = valueArrayDatas(i).toObjectArray(dataType.valueType) + Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length) + Array.copy(valueArray, 0, finalValueArray, position, valueArray.length) + position += keyArray.length + } + + new ArrayBasedMapData(new GenericArrayData(finalKeyArray), + new GenericArrayData(finalValueArray)) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, c => dataType match { - case _: StringType => stringCodeGen(ev, c) - case _: ArrayType => arrayCodeGen(ctx, ev, c) - }) - } + val mapCodes = children.map(_.genCode(ctx)) + val keyType = dataType.keyType + val valueType = dataType.valueType + val argsName = ctx.freshName("args") + val hasNullName = ctx.freshName("hasNull") + val mapDataClass = classOf[MapData].getName + val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName + val arrayDataClass = classOf[ArrayData].getName + + val init = + s""" + |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}]; + |boolean ${ev.isNull}, $hasNullName = false; + |$mapDataClass ${ev.value} = null; + """.stripMargin - private def stringCodeGen(ev: ExprCode, childName: String): String = { - s"${ev.value} = ($childName).reverse();" - } + val assignments = mapCodes.zipWithIndex.map { case (m, i) => + s""" + |if (!$hasNullName) { + | ${m.code} + | $argsName[$i] = ${m.value}; + | if (${m.isNull}) { + | $hasNullName = true; + | } + |} + """.stripMargin + } - private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { - val length = ctx.freshName("length") - val javaElementType = CodeGenerator.javaType(elementType) - val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = assignments, + funcName = "getMapConcatInputs", + extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", hasNullName) :: Nil, + returnType = "boolean", + makeSplitFunction = body => + s""" + |$body + |return $hasNullName; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$hasNullName = $funcCall;").mkString("\n") + ) - val initialization = if (isPrimitiveType) { - s"$childName.copy()" + val idxName = ctx.freshName("idx") + val numElementsName = ctx.freshName("numElems") + val finKeysName = ctx.freshName("finalKeys") + val finValsName = ctx.freshName("finalValues") + + val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) { + genCodeForPrimitiveArrays(ctx, keyType, false) } else { - s"new ${classOf[GenericArrayData].getName()}(new Object[$length])" + genCodeForNonPrimitiveArrays(ctx, keyType) } - val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length - - val swapAssigments = if (isPrimitiveType) { - val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType) - val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index) - s"""|boolean isNullAtK = ${ev.value}.isNullAt(k); - |boolean isNullAtL = ${ev.value}.isNullAt(l); - |if(!isNullAtK) { - | $javaElementType el = ${getCall("k")}; - | if(!isNullAtL) { - | ${ev.value}.$setFunc(k, ${getCall("l")}); - | } else { - | ${ev.value}.setNullAt(k); - | } - | ${ev.value}.$setFunc(l, el); - |} else if (!isNullAtL) { - | ${ev.value}.$setFunc(k, ${getCall("l")}); - | ${ev.value}.setNullAt(l); - |}""".stripMargin + val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) { + genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull) } else { - s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});" + genCodeForNonPrimitiveArrays(ctx, valueType) } - s""" - |final int $length = $childName.numElements(); - |${ev.value} = $initialization; - |for(int k = 0; k < $numberOfIterations; k++) { - | int l = $length - k - 1; - | $swapAssigments - |} - """.stripMargin - } + val keyArgsName = ctx.freshName("keyArgs") + val valArgsName = ctx.freshName("valArgs") - override def prettyName: String = "reverse" -} + val mapMerge = + s""" + |${ev.isNull} = $hasNullName; + |if (!${ev.isNull}) { + | $arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}]; + | $arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}]; + | long $numElementsName = 0; + | for (int $idxName = 0; $idxName < $argsName.length; $idxName++) { + | $keyArgsName[$idxName] = $argsName[$idxName].keyArray(); + | $valArgsName[$idxName] = $argsName[$idxName].valueArray(); + | $numElementsName += $argsName[$idxName].numElements(); + | } + | if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful attempt to concat maps with " + + | $numElementsName + " elements due to exceeding the map size limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + | } + | $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName, + | (int) $numElementsName); + | $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName, + | (int) $numElementsName); + | ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName); + |} + """.stripMargin -/** - * Checks if the array (left) has the element (right) - */ -@ExpressionDescription( - usage = "_FUNC_(array, value) - Returns true if the array contains the value.", - examples = """ - Examples: - > SELECT _FUNC_(array(1, 2, 3), 2); - true - """) -case class ArrayContains(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + ev.copy( + code = code""" + |$init + |$codes + |$mapMerge + """.stripMargin) + } - override def dataType: DataType = BooleanType + private def genCodeForPrimitiveArrays( + ctx: CodegenContext, + elementType: DataType, + checkForNull: Boolean): String = { + val counter = ctx.freshName("counter") + val arrayData = ctx.freshName("arrayData") + val argsName = ctx.freshName("args") + val numElemName = ctx.freshName("numElements") + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - override def inputTypes: Seq[AbstractDataType] = right.dataType match { - case NullType => Seq.empty - case _ => left.dataType match { - case n @ ArrayType(element, _) => Seq(n, element) - case _ => Seq.empty - } - } + val setterCode1 = + s""" + |$arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")} + |);""".stripMargin - override def checkInputDataTypes(): TypeCheckResult = { - if (right.dataType == NullType) { - TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") - } else if (!left.dataType.isInstanceOf[ArrayType] - || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { - TypeCheckResult.TypeCheckFailure( - "Arguments must be an array followed by a value of same type as the array members") + val setterCode = if (checkForNull) { + s""" + |if ($argsName[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + |} else { + | $setterCode1 + |}""".stripMargin } else { - TypeCheckResult.TypeCheckSuccess + setterCode1 } - } - override def nullable: Boolean = { - left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull + s""" + |new Object() { + | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) { + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | $setterCode + | $counter++; + | } + | } + | return $arrayData; + | } + |}""".stripMargin.stripPrefix("\n") } - override def nullSafeEval(arr: Any, value: Any): Any = { - var hasNull = false - arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == null) { - hasNull = true - } else if (v == value) { - return true - } - ) - if (hasNull) { - null - } else { - false - } - } + private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayData = ctx.freshName("arrayObjects") + val counter = ctx.freshName("counter") + val argsName = ctx.freshName("args") + val numElemName = ctx.freshName("numElements") - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (arr, value) => { - val i = ctx.freshName("i") - val getValue = CodeGenerator.getValue(arr, right.dataType, i) - s""" - for (int $i = 0; $i < $arr.numElements(); $i ++) { - if ($arr.isNullAt($i)) { - ${ev.isNull} = true; - } else if (${ctx.genEqual(right.dataType, value, getValue)}) { - ${ev.isNull} = false; - ${ev.value} = true; - break; - } - } - """ - }) + s""" + |new Object() { + | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {; + | Object[] $arrayData = new Object[$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + | } + |}""".stripMargin.stripPrefix("\n") } - override def prettyName: String = "array_contains" + override def prettyName: String = "map_concat" } /** - * Returns the minimum value in the array. + * Returns a map created from the given array of entries. */ @ExpressionDescription( - usage = "_FUNC_(array) - Returns the minimum value in the array. NULL elements are skipped.", + usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.", examples = """ Examples: - > SELECT _FUNC_(array(1, 20, null, 3)); - 1 - """, since = "2.4.0") -case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b'))); + {1:"a",2:"b"} + """, + since = "2.4.0") +case class MapFromEntries(child: Expression) extends UnaryExpression { - override def nullable: Boolean = true + @transient + private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { + case ArrayType( + StructType(Array( + StructField(_, keyType, keyNullable, _), + StructField(_, valueType, valueNullable, _))), + containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull)) + case _ => None + } - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + private def nullEntries: Boolean = dataTypeDetails.get._3 - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + override def nullable: Boolean = child.nullable || nullEntries - override def checkInputDataTypes(): TypeCheckResult = { - val typeCheckResult = super.checkInputDataTypes() - if (typeCheckResult.isSuccess) { - TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") - } else { - typeCheckResult - } - } + override def dataType: MapType = dataTypeDetails.get._1 - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val childGen = child.genCode(ctx) - val javaType = CodeGenerator.javaType(dataType) - val i = ctx.freshName("i") - val item = ExprCode("", - isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), - value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) - ev.copy(code = - s""" - |${childGen.code} - |boolean ${ev.isNull} = true; - |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - |if (!${childGen.isNull}) { - | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { - | ${ctx.reassignIfSmaller(dataType, ev, item)} - | } - |} - """.stripMargin) + override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { + case Some(_) => TypeCheckResult.TypeCheckSuccess + case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + + s"${child.dataType.simpleString} type. $prettyName accepts only arrays of pair structs.") } override protected def nullSafeEval(input: Any): Any = { - var min: Any = null - input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => - if (item != null && (min == null || ordering.lt(item, min))) { - min = item + val arrayData = input.asInstanceOf[ArrayData] + val numEntries = arrayData.numElements() + var i = 0 + if(nullEntries) { + while (i < numEntries) { + if (arrayData.isNullAt(i)) return null + i += 1 } - ) - min + } + val keyArray = new Array[AnyRef](numEntries) + val valueArray = new Array[AnyRef](numEntries) + i = 0 + while (i < numEntries) { + val entry = arrayData.getStruct(i, 2) + val key = entry.get(0, dataType.keyType) + if (key == null) { + throw new RuntimeException("The first field from a struct (key) can't be null.") + } + keyArray.update(i, key) + val value = entry.get(1, dataType.valueType) + valueArray.update(i, value) + i += 1 + } + ArrayBasedMapData(keyArray, valueArray) } - override def dataType: DataType = child.dataType match { - case ArrayType(dt, _) => dt - case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numEntries = ctx.freshName("numEntries") + val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType) + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val code = if (isKeyPrimitive && isValuePrimitive) { + genCodeForPrimitiveElements(ctx, c, ev.value, numEntries) + } else { + genCodeForAnyElements(ctx, c, ev.value, numEntries) + } + ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) { + s""" + |final int $numEntries = $c.numElements(); + |$code + """.stripMargin + } + }) } - override def prettyName: String = "array_min" -} - -/** - * Returns the maximum value in the array. - */ -@ExpressionDescription( - usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.", - examples = """ - Examples: - > SELECT _FUNC_(array(1, 20, null, 3)); - 20 - """, since = "2.4.0") -case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - - override def nullable: Boolean = true - - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) - - override def checkInputDataTypes(): TypeCheckResult = { - val typeCheckResult = super.checkInputDataTypes() - if (typeCheckResult.isSuccess) { - TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + private def genCodeForAssignmentLoop( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String, + keyAssignment: (String, String) => String, + valueAssignment: (String, String) => String): String = { + val entry = ctx.freshName("entry") + val i = ctx.freshName("idx") + + val nullKeyCheck = if (dataTypeDetails.get._2) { + s""" + |if ($entry.isNullAt(0)) { + | throw new RuntimeException("The first field from a struct (key) can't be null."); + |} + """.stripMargin } else { - typeCheckResult + "" } - } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val childGen = child.genCode(ctx) - val javaType = CodeGenerator.javaType(dataType) - val i = ctx.freshName("i") - val item = ExprCode("", - isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), - value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) - ev.copy(code = - s""" - |${childGen.code} - |boolean ${ev.isNull} = true; - |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - |if (!${childGen.isNull}) { - | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { - | ${ctx.reassignIfGreater(dataType, ev, item)} - | } - |} - """.stripMargin) + s""" + |for (int $i = 0; $i < $numEntries; $i++) { + | InternalRow $entry = $childVariable.getStruct($i, 2); + | $nullKeyCheck + | ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), i)} + | ${valueAssignment(entry, i)} + |} + """.stripMargin } - override protected def nullSafeEval(input: Any): Any = { - var max: Any = null - input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => - if (item != null && (max == null || ordering.gt(item, max))) { - max = item + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String): String = { + val byteArraySize = ctx.freshName("byteArraySize") + val keySectionSize = ctx.freshName("keySectionSize") + val valueSectionSize = ctx.freshName("valueSectionSize") + val data = ctx.freshName("byteArray") + val unsafeMapData = ctx.freshName("unsafeMapData") + val keyArrayData = ctx.freshName("keyArrayData") + val valueArrayData = ctx.freshName("valueArrayData") + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val keySize = dataType.keyType.defaultSize + val valueSize = dataType.valueType.defaultSize + val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)" + val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)" + val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType) + + val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);" + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);" + if (dataType.valueContainsNull) { + s""" + |if ($entry.isNullAt(1)) { + | $valueArrayData.setNullAt($idx); + |} else { + | $valueNullUnsafeAssignment + |} + """.stripMargin + } else { + valueNullUnsafeAssignment } + } + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + mapData, + numEntries, + keyAssignment, + valueAssignment ) - max + + s""" + |final long $keySectionSize = $kByteSize; + |final long $valueSectionSize = $vByteSize; + |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize; + |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)} + |} else { + | final byte[] $data = new byte[(int)$byteArraySize]; + | UnsafeMapData $unsafeMapData = new UnsafeMapData(); + | Platform.putLong($data, $baseOffset, $keySectionSize); + | Platform.putLong($data, ${baseOffset + 8}, $numEntries); + | Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries); + | $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize); + | ArrayData $keyArrayData = $unsafeMapData.keyArray(); + | ArrayData $valueArrayData = $unsafeMapData.valueArray(); + | $assignmentLoop + | $mapData = $unsafeMapData; + |} + """.stripMargin } - override def dataType: DataType = child.dataType match { - case ArrayType(dt, _) => dt - case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + private def genCodeForAnyElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String): String = { + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val mapDataClass = classOf[ArrayBasedMapData].getName() + + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + if (dataType.valueContainsNull && isValuePrimitive) { + s"$values[$idx] = $entry.isNullAt(1) ? null : (Object)$value;" + } else { + s"$values[$idx] = $value;" + } + } + val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;" + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + mapData, + numEntries, + keyAssignment, + valueAssignment) + + s""" + |final Object[] $keys = new Object[$numEntries]; + |final Object[] $values = new Object[$numEntries]; + |$assignmentLoop + |$mapData = $mapDataClass.apply($keys, $values); + """.stripMargin } - override def prettyName: String = "array_max" + override def prettyName: String = "map_from_entries" } /** - * Returns the position of the first occurrence of element in the given array as long. - * Returns 0 if the given value could not be found in the array. Returns null if either of - * the arguments are null - * - * NOTE: that this is not zero based, but 1-based index. The first element in the array has - * index 1. + * Common base class for [[SortArray]] and [[ArraySort]]. */ -@ExpressionDescription( - usage = """ - _FUNC_(array, element) - Returns the (1-based) index of the first element of the array as long. - """, - examples = """ - Examples: - > SELECT _FUNC_(array(3, 2, 1), 1); - 3 - """, - since = "2.4.0") -case class ArrayPosition(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { +trait ArraySortLike extends ExpectsInputTypes { + protected def arrayExpression: Expression - override def dataType: DataType = LongType - override def inputTypes: Seq[AbstractDataType] = - Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + protected def nullOrder: NullOrder - override def nullSafeEval(arr: Any, value: Any): Any = { - arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == value) { - return (i + 1).toLong + @transient + private lazy val lt: Comparator[Any] = { + val ordering = arrayExpression.dataType match { + case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + } + + new Comparator[Any]() { + override def compare(o1: Any, o2: Any): Int = { + if (o1 == null && o2 == null) { + 0 + } else if (o1 == null) { + nullOrder + } else if (o2 == null) { + -nullOrder + } else { + ordering.compare(o1, o2) + } } - ) - 0L + } } - override def prettyName: String = "array_position" + @transient + private lazy val gt: Comparator[Any] = { + val ordering = arrayExpression.dataType match { + case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (arr, value) => { - val pos = ctx.freshName("arrayPosition") - val i = ctx.freshName("i") - val getValue = CodeGenerator.getValue(arr, right.dataType, i) + new Comparator[Any]() { + override def compare(o1: Any, o2: Any): Int = { + if (o1 == null && o2 == null) { + 0 + } else if (o1 == null) { + -nullOrder + } else if (o2 == null) { + nullOrder + } else { + ordering.compare(o2, o1) + } + } + } + } + + def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType + def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull + + def sortEval(array: Any, ascending: Boolean): Any = { + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + if (elementType != NullType) { + java.util.Arrays.sort(data, if (ascending) lt else gt) + } + new GenericArrayData(data.asInstanceOf[Array[Any]]) + } + + def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = { + val arrayData = classOf[ArrayData].getName + val genericArrayData = classOf[GenericArrayData].getName + val unsafeArrayData = classOf[UnsafeArrayData].getName + val array = ctx.freshName("array") + val c = ctx.freshName("c") + if (elementType == NullType) { + s"${ev.value} = $base.copy();" + } else { + val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType) + val sortOrder = ctx.freshName("sortOrder") + val o1 = ctx.freshName("o1") + val o2 = ctx.freshName("o2") + val jt = CodeGenerator.javaType(elementType) + val comp = if (CodeGenerator.isPrimitiveType(elementType)) { + val bt = CodeGenerator.boxedType(elementType) + val v1 = ctx.freshName("v1") + val v2 = ctx.freshName("v2") + s""" + |$jt $v1 = (($bt) $o1).${jt}Value(); + |$jt $v2 = (($bt) $o2).${jt}Value(); + |int $c = ${ctx.genComp(elementType, v1, v2)}; + """.stripMargin + } else { + s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};" + } + val nonNullPrimitiveAscendingSort = + if (CodeGenerator.isPrimitiveType(elementType) && !containsNull) { + val javaType = CodeGenerator.javaType(elementType) + val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |if ($order) { + | $javaType[] $array = $base.to${primitiveTypeName}Array(); + | java.util.Arrays.sort($array); + | ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array); + |} else + """.stripMargin + } else { + "" + } s""" - |int $pos = 0; - |for (int $i = 0; $i < $arr.numElements(); $i ++) { - | if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) { - | $pos = $i + 1; - | break; - | } + |$nonNullPrimitiveAscendingSort + |{ + | Object[] $array = $base.toObjectArray($elementTypeTerm); + | final int $sortOrder = $order ? 1 : -1; + | java.util.Arrays.sort($array, new java.util.Comparator() { + | @Override public int compare(Object $o1, Object $o2) { + | if ($o1 == null && $o2 == null) { + | return 0; + | } else if ($o1 == null) { + | return $sortOrder * $nullOrder; + | } else if ($o2 == null) { + | return -$sortOrder * $nullOrder; + | } + | $comp + | return $sortOrder * $c; + | } + | }); + | ${ev.value} = new $genericArrayData($array); |} - |${ev.value} = (long) $pos; """.stripMargin - }) + } + } + +} + +object ArraySortLike { + type NullOrder = Int + // Least: place null element at the first of the array for ascending order + // Greatest: place null element at the end of the array for ascending order + object NullOrder { + val Least: NullOrder = -1 + val Greatest: NullOrder = 1 } } /** - * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`. + * Sorts the input array in ascending / descending order according to the natural ordering of + * the array elements and returns it. */ +// scalastyle:off line.size.limit @ExpressionDescription( usage = """ - _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, - accesses elements from the last to the first. Returns NULL if the index exceeds the length - of the array. - - _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map + _FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order + according to the natural ordering of the array elements. Null elements will be placed + at the beginning of the returned array in ascending order or at the end of the returned + array in descending order. """, examples = """ Examples: - > SELECT _FUNC_(array(1, 2, 3), 2); - 2 - > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2); - "b" - """, - since = "2.4.0") -case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true); + [null,"a","b","c","d"] + """) +// scalastyle:on line.size.limit +case class SortArray(base: Expression, ascendingOrder: Expression) + extends BinaryExpression with ArraySortLike { - override def dataType: DataType = left.dataType match { - case ArrayType(elementType, _) => elementType - case MapType(_, valueType, _) => valueType - } + def this(e: Expression) = this(e, Literal(true)) - override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(ArrayType, MapType), - left.dataType match { - case _: ArrayType => IntegerType - case _: MapType => left.dataType.asInstanceOf[MapType].keyType + override def left: Expression = base + override def right: Expression = ascendingOrder + override def dataType: DataType = base.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) + + override def arrayExpression: Expression = base + override def nullOrder: NullOrder = NullOrder.Least + + override def checkInputDataTypes(): TypeCheckResult = base.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + ascendingOrder match { + case Literal(_: Boolean, BooleanType) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + "Sort order in second argument requires a boolean literal.") } - ) + case ArrayType(dt, _) => + val dtSimple = dt.simpleString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") } - override def nullable: Boolean = true - - override def nullSafeEval(value: Any, ordinal: Any): Any = { - left.dataType match { - case _: ArrayType => - val array = value.asInstanceOf[ArrayData] - val index = ordinal.asInstanceOf[Int] - if (array.numElements() < math.abs(index)) { - null - } else { - val idx = if (index == 0) { - throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") - } else if (index > 0) { - index - 1 - } else { - array.numElements() + index - } - if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { - null - } else { - array.get(idx, dataType) - } - } - case _: MapType => - getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) - } + override def nullSafeEval(array: Any, ascending: Any): Any = { + sortEval(array, ascending.asInstanceOf[Boolean]) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - left.dataType match { - case _: ArrayType => - nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - val index = ctx.freshName("elementAtIndex") - val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |if ($eval1.isNullAt($index)) { - | ${ev.isNull} = true; - |} else - """.stripMargin - } else { - "" - } - s""" - |int $index = (int) $eval2; - |if ($eval1.numElements() < Math.abs($index)) { - | ${ev.isNull} = true; - |} else { - | if ($index == 0) { - | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1"); - | } else if ($index > 0) { - | $index--; - | } else { - | $index += $eval1.numElements(); - | } - | $nullCheck - | { - | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; - | } - |} - """.stripMargin - }) - case _: MapType => - doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType]) - } + nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order)) } - override def prettyName: String = "element_at" + override def prettyName: String = "sort_array" } + /** - * Concatenates multiple input columns together into a single column. - * The function works with strings, binary and compatible array columns. + * Sorts the input array in ascending order according to the natural ordering of + * the array elements and returns it. */ +// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + usage = """ + _FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must + be orderable. Null elements will be placed at the end of the returned array. + """, examples = """ Examples: - > SELECT _FUNC_('Spark', 'SQL'); - SparkSQL - > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); - | [1,2,3,4,5,6] - """) -case class Concat(children: Seq[Expression]) extends Expression { + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); + ["a","b","c","d",null] + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike { - private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + override def dataType: DataType = child.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - val allowedTypes = Seq(StringType, BinaryType, ArrayType) + override def arrayExpression: Expression = child + override def nullOrder: NullOrder = NullOrder.Greatest - override def checkInputDataTypes(): TypeCheckResult = { - if (children.isEmpty) { + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => TypeCheckResult.TypeCheckSuccess - } else { - val childTypes = children.map(_.dataType) - if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) { - return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + - s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) - } - TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") - } + case ArrayType(dt, _) => + val dtSimple = dt.simpleString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") } - override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + override def nullSafeEval(array: Any): Any = { + sortEval(array, true) + } - lazy val javaType: String = CodeGenerator.javaType(dataType) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true")) + } - override def nullable: Boolean = children.exists(_.nullable) + override def prettyName: String = "array_sort" +} - override def foldable: Boolean = children.forall(_.foldable) +/** + * Returns a reversed string or an array with reverse order of elements. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.", + examples = """ + Examples: + > SELECT _FUNC_('Spark SQL'); + LQS krapS + > SELECT _FUNC_(array(2, 1, 4, 3)); + [3, 4, 1, 2] + """, + since = "1.5.0", + note = "Reverse logic for arrays is available since 2.4.0." +) +case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - override def eval(input: InternalRow): Any = dataType match { - case BinaryType => - val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) - ByteArray.concat(inputs: _*) - case StringType => - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs : _*) - case ArrayType(elementType, _) => - val inputs = children.toStream.map(_.eval(input)) - if (inputs.contains(null)) { - null - } else { - val arrayData = inputs.map(_.asInstanceOf[ArrayData]) - val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) - if (numberOfElements > MAX_ARRAY_LENGTH) { + // Input types are utilized by type coercion in ImplicitTypeCasts. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) + + override def dataType: DataType = child.dataType + + lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(input: Any): Any = input match { + case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) + case s: UTF8String => s.reverse() + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => dataType match { + case _: StringType => stringCodeGen(ev, c) + case _: ArrayType => arrayCodeGen(ctx, ev, c) + }) + } + + private def stringCodeGen(ev: ExprCode, childName: String): String = { + s"${ev.value} = ($childName).reverse();" + } + + private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { + val length = ctx.freshName("length") + val javaElementType = CodeGenerator.javaType(elementType) + val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + + val initialization = if (isPrimitiveType) { + s"$childName.copy()" + } else { + s"new ${classOf[GenericArrayData].getName()}(new Object[$length])" + } + + val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length + + val swapAssigments = if (isPrimitiveType) { + val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType) + val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index) + s"""|boolean isNullAtK = ${ev.value}.isNullAt(k); + |boolean isNullAtL = ${ev.value}.isNullAt(l); + |if(!isNullAtK) { + | $javaElementType el = ${getCall("k")}; + | if(!isNullAtL) { + | ${ev.value}.$setFunc(k, ${getCall("l")}); + | } else { + | ${ev.value}.setNullAt(k); + | } + | ${ev.value}.$setFunc(l, el); + |} else if (!isNullAtL) { + | ${ev.value}.$setFunc(k, ${getCall("l")}); + | ${ev.value}.setNullAt(l); + |}""".stripMargin + } else { + s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});" + } + + s""" + |final int $length = $childName.numElements(); + |${ev.value} = $initialization; + |for(int k = 0; k < $numberOfIterations; k++) { + | int l = $length - k - 1; + | $swapAssigments + |} + """.stripMargin + } + + override def prettyName: String = "reverse" +} + +/** + * Checks if the array (left) has the element (right) + */ +@ExpressionDescription( + usage = "_FUNC_(array, value) - Returns true if the array contains the value.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 2); + true + """) +case class ArrayContains(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = BooleanType + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + + override def inputTypes: Seq[AbstractDataType] = right.dataType match { + case NullType => Seq.empty + case _ => left.dataType match { + case n @ ArrayType(element, _) => Seq(n, element) + case _ => Seq.empty + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (right.dataType == NullType) { + TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") + } else if (!left.dataType.isInstanceOf[ArrayType] + || !left.dataType.asInstanceOf[ArrayType].elementType.sameType(right.dataType)) { + TypeCheckResult.TypeCheckFailure( + "Arguments must be an array followed by a value of same type as the array members") + } else { + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + } + } + + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull + } + + override def nullSafeEval(arr: Any, value: Any): Any = { + var hasNull = false + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == null) { + hasNull = true + } else if (ordering.equiv(v, value)) { + return true + } + ) + if (hasNull) { + null + } else { + false + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (arr, value) => { + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(arr, right.dataType, i) + s""" + for (int $i = 0; $i < $arr.numElements(); $i ++) { + if ($arr.isNullAt($i)) { + ${ev.isNull} = true; + } else if (${ctx.genEqual(right.dataType, value, getValue)}) { + ${ev.isNull} = false; + ${ev.value} = true; + break; + } + } + """ + }) + } + + override def prettyName: String = "array_contains" +} + +/** + * Checks if the two arrays contain at least one common element. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least a non-null element present also in a2. If the arrays have no common element and they are both non-empty and either of them contains a null element null is returned, false otherwise.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5)); + true + """, since = "2.4.0") +// scalastyle:off line.size.limit +case class ArraysOverlap(left: Expression, right: Expression) + extends BinaryArrayExpressionWithImplicitCast { + + override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") + case failure => failure + } + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + @transient private lazy val doEvaluation = if (elementTypeSupportEquals) { + fastEval _ + } else { + bruteForceEval _ + } + + override def dataType: DataType = BooleanType + + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull || + right.dataType.asInstanceOf[ArrayType].containsNull + } + + override def nullSafeEval(a1: Any, a2: Any): Any = { + doEvaluation(a1.asInstanceOf[ArrayData], a2.asInstanceOf[ArrayData]) + } + + /** + * A fast implementation which puts all the elements from the smaller array in a set + * and then performs a lookup on it for each element of the bigger one. + * This eval mode works only for data types which implements properly the equals method. + */ + private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = { + var hasNull = false + val (bigger, smaller) = if (arr1.numElements() > arr2.numElements()) { + (arr1, arr2) + } else { + (arr2, arr1) + } + if (smaller.numElements() > 0) { + val smallestSet = new mutable.HashSet[Any] + smaller.foreach(elementType, (_, v) => + if (v == null) { + hasNull = true + } else { + smallestSet += v + }) + bigger.foreach(elementType, (_, v1) => + if (v1 == null) { + hasNull = true + } else if (smallestSet.contains(v1)) { + return true + } + ) + } + if (hasNull) { + null + } else { + false + } + } + + /** + * A slower evaluation which performs a nested loop and supports all the data types. + */ + private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = { + var hasNull = false + if (arr1.numElements() > 0 && arr2.numElements() > 0) { + arr1.foreach(elementType, (_, v1) => + if (v1 == null) { + hasNull = true + } else { + arr2.foreach(elementType, (_, v2) => + if (v2 == null) { + hasNull = true + } else if (ordering.equiv(v1, v2)) { + return true + } + ) + }) + } + if (hasNull) { + null + } else { + false + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (a1, a2) => { + val smaller = ctx.freshName("smallerArray") + val bigger = ctx.freshName("biggerArray") + val comparisonCode = if (elementTypeSupportEquals) { + fastCodegen(ctx, ev, smaller, bigger) + } else { + bruteForceCodegen(ctx, ev, smaller, bigger) + } + s""" + |ArrayData $smaller; + |ArrayData $bigger; + |if ($a1.numElements() > $a2.numElements()) { + | $bigger = $a1; + | $smaller = $a2; + |} else { + | $smaller = $a1; + | $bigger = $a2; + |} + |if ($smaller.numElements() > 0) { + | $comparisonCode + |} + """.stripMargin + }) + } + + /** + * Code generation for a fast implementation which puts all the elements from the smaller array + * in a set and then performs a lookup on it for each element of the bigger one. + * It works only for data types which implements properly the equals method. + */ + private def fastCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { + val i = ctx.freshName("i") + val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i) + val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val javaElementClass = CodeGenerator.boxedType(elementType) + val javaSet = classOf[java.util.HashSet[_]].getName + val set = ctx.freshName("set") + val addToSetFromSmallerCode = nullSafeElementCodegen( + smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;") + val elementIsInSetCode = nullSafeElementCodegen( + bigger, + i, + s""" + |if ($set.contains($getFromBigger)) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | break; + |} + """.stripMargin, + s"${ev.isNull} = true;") + s""" + |$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>(); + |for (int $i = 0; $i < $smaller.numElements(); $i ++) { + | $addToSetFromSmallerCode + |} + |for (int $i = 0; $i < $bigger.numElements(); $i ++) { + | $elementIsInSetCode + |} + """.stripMargin + } + + /** + * Code generation for a slower evaluation which performs a nested loop and supports all the data types. + */ + private def bruteForceCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j) + val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val compareValues = nullSafeElementCodegen( + smaller, + j, + s""" + |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + |} + """.stripMargin, + s"${ev.isNull} = true;") + val isInSmaller = nullSafeElementCodegen( + bigger, + i, + s""" + |for (int $j = 0; $j < $smaller.numElements() && !${ev.value}; $j ++) { + | $compareValues + |} + """.stripMargin, + s"${ev.isNull} = true;") + s""" + |for (int $i = 0; $i < $bigger.numElements() && !${ev.value}; $i ++) { + | $isInSmaller + |} + """.stripMargin + } + + def nullSafeElementCodegen( + arrayVar: String, + index: String, + code: String, + isNullCode: String): String = { + if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) { + s""" + |if ($arrayVar.isNullAt($index)) { + | $isNullCode + |} else { + | $code + |} + """.stripMargin + } else { + code + } + } + + override def prettyName: String = "arrays_overlap" +} + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def children: Seq[Expression] = Seq(x, start, length) + + lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { + val startInt = startVal.asInstanceOf[Int] + val lengthInt = lengthVal.asInstanceOf[Int] + val arr = xVal.asInstanceOf[ArrayData] + val startIndex = if (startInt == 0) { + throw new RuntimeException( + s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") + } else if (startInt < 0) { + startInt + arr.numElements() + } else { + startInt - 1 + } + if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + + "length must be greater than or equal to 0.") + } + // startIndex can be negative if start is negative and its absolute value is greater than the + // number of elements in the array + if (startIndex < 0 || startIndex >= arr.numElements()) { + return new GenericArrayData(Array.empty[AnyRef]) + } + val data = arr.toSeq[AnyRef](elementType) + new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (x, start, length) => { + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + | + "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + | + "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |${genCodeForResult(ctx, ev, x, startIdx, resLength)} + """.stripMargin + }) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + startIdx: String, + resLength: String): String = { + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |Object[] $values; + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) { + | $values[$i] = $getValue; + | } + |} + |${ev.value} = new $arrayClass($values); + """.stripMargin + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $resLength = 0; + |} + |${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")} + |for (int $i = 0; $i < $resLength; $i ++) { + | if ($inputArray.isNullAt($i + $startIdx)) { + | $values.setNullAt($i); + | } else { + | $values.set$primitiveValueTypeName($i, $getValue); + | } + |} + |${ev.value} = $values; + """.stripMargin + } + } +} + +/** + * Creates a String containing all the elements of the input array separated by the delimiter. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, delimiter[, nullReplacement]) - Concatenates the elements of the given array + using the delimiter and an optional string to replace nulls. If no value is set for + nullReplacement, any null value is filtered.""", + examples = """ + Examples: + > SELECT _FUNC_(array('hello', 'world'), ' '); + hello world + > SELECT _FUNC_(array('hello', null ,'world'), ' '); + hello world + > SELECT _FUNC_(array('hello', null ,'world'), ' ', ','); + hello , world + """, since = "2.4.0") +case class ArrayJoin( + array: Expression, + delimiter: Expression, + nullReplacement: Option[Expression]) extends Expression with ExpectsInputTypes { + + def this(array: Expression, delimiter: Expression) = this(array, delimiter, None) + + def this(array: Expression, delimiter: Expression, nullReplacement: Expression) = + this(array, delimiter, Some(nullReplacement)) + + override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { + Seq(ArrayType(StringType), StringType, StringType) + } else { + Seq(ArrayType(StringType), StringType) + } + + override def children: Seq[Expression] = if (nullReplacement.isDefined) { + Seq(array, delimiter, nullReplacement.get) + } else { + Seq(array, delimiter) + } + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = { + val arrayEval = array.eval(input) + if (arrayEval == null) return null + val delimiterEval = delimiter.eval(input) + if (delimiterEval == null) return null + val nullReplacementEval = nullReplacement.map(_.eval(input)) + if (nullReplacementEval.contains(null)) return null + + val buffer = new UTF8StringBuilder() + var firstItem = true + val nullHandling = nullReplacementEval match { + case Some(rep) => (prependDelimiter: Boolean) => { + if (!prependDelimiter) { + buffer.append(delimiterEval.asInstanceOf[UTF8String]) + } + buffer.append(rep.asInstanceOf[UTF8String]) + true + } + case None => (_: Boolean) => false + } + arrayEval.asInstanceOf[ArrayData].foreach(StringType, (_, item) => { + if (item == null) { + if (nullHandling(firstItem)) { + firstItem = false + } + } else { + if (!firstItem) { + buffer.append(delimiterEval.asInstanceOf[UTF8String]) + } + buffer.append(item.asInstanceOf[UTF8String]) + firstItem = false + } + }) + buffer.build() + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val code = nullReplacement match { + case Some(replacement) => + val replacementGen = replacement.genCode(ctx) + val nullHandling = (buffer: String, delimiter: String, firstItem: String) => { + s""" + |if (!$firstItem) { + | $buffer.append($delimiter); + |} + |$buffer.append(${replacementGen.value}); + |$firstItem = false; + """.stripMargin + } + val execCode = if (replacement.nullable) { + ctx.nullSafeExec(replacement.nullable, replacementGen.isNull) { + genCodeForArrayAndDelimiter(ctx, ev, nullHandling) + } + } else { + genCodeForArrayAndDelimiter(ctx, ev, nullHandling) + } + s""" + |${replacementGen.code} + |$execCode + """.stripMargin + case None => genCodeForArrayAndDelimiter(ctx, ev, + (_: String, _: String, _: String) => "// nulls are ignored") + } + if (nullable) { + ev.copy( + code""" + |boolean ${ev.isNull} = true; + |UTF8String ${ev.value} = null; + |$code + """.stripMargin) + } else { + ev.copy( + code""" + |UTF8String ${ev.value} = null; + |$code + """.stripMargin, FalseLiteral) + } + } + + private def genCodeForArrayAndDelimiter( + ctx: CodegenContext, + ev: ExprCode, + nullEval: (String, String, String) => String): String = { + val arrayGen = array.genCode(ctx) + val delimiterGen = delimiter.genCode(ctx) + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val i = ctx.freshName("i") + val firstItem = ctx.freshName("firstItem") + val resultCode = + s""" + |$bufferClass $buffer = new $bufferClass(); + |boolean $firstItem = true; + |for (int $i = 0; $i < ${arrayGen.value}.numElements(); $i ++) { + | if (${arrayGen.value}.isNullAt($i)) { + | ${nullEval(buffer, delimiterGen.value, firstItem)} + | } else { + | if (!$firstItem) { + | $buffer.append(${delimiterGen.value}); + | } + | $buffer.append(${CodeGenerator.getValue(arrayGen.value, StringType, i)}); + | $firstItem = false; + | } + |} + |${ev.value} = $buffer.build();""".stripMargin + + if (array.nullable || delimiter.nullable) { + arrayGen.code + ctx.nullSafeExec(array.nullable, arrayGen.isNull) { + delimiterGen.code + ctx.nullSafeExec(delimiter.nullable, delimiterGen.isNull) { + s""" + |${ev.isNull} = false; + |$resultCode""".stripMargin + } + } + } else { + s""" + |${arrayGen.code} + |${delimiterGen.code} + |$resultCode""".stripMargin + } + } + + override def dataType: DataType = StringType + + override def prettyName: String = "array_join" +} + +/** + * Returns the minimum value in the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the minimum value in the array. NULL elements are skipped.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 1 + """, since = "2.4.0") +case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode(EmptyBlock, + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) + ev.copy(code = + code""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfSmaller(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var min: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (min == null || ordering.lt(item, min))) { + min = item + } + ) + min + } + + override def dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + } + + override def prettyName: String = "array_min" +} + +/** + * Returns the maximum value in the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 20 + """, since = "2.4.0") +case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode(EmptyBlock, + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) + ev.copy(code = + code""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfGreater(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var max: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (max == null || ordering.gt(item, max))) { + max = item + } + ) + max + } + + override def dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + } + + override def prettyName: String = "array_max" +} + + +/** + * Returns the position of the first occurrence of element in the given array as long. + * Returns 0 if the given value could not be found in the array. Returns null if either of + * the arguments are null + * + * NOTE: that this is not zero based, but 1-based index. The first element in the array has + * index 1. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, element) - Returns the (1-based) index of the first element of the array as long. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(3, 2, 1), 1); + 3 + """, + since = "2.4.0") +case class ArrayPosition(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + + override def dataType: DataType = LongType + + override def inputTypes: Seq[AbstractDataType] = { + val elementType = left.dataType match { + case t: ArrayType => t.elementType + case _ => AnyDataType + } + Seq(ArrayType, elementType) + } + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + } + } + + override def nullSafeEval(arr: Any, value: Any): Any = { + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v != null && ordering.equiv(v, value)) { + return (i + 1).toLong + } + ) + 0L + } + + override def prettyName: String = "array_position" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (arr, value) => { + val pos = ctx.freshName("arrayPosition") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(arr, right.dataType, i) + s""" + |int $pos = 0; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) { + | $pos = $i + 1; + | break; + | } + |} + |${ev.value} = (long) $pos; + """.stripMargin + }) + } +} + +/** + * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, + accesses elements from the last to the first. Returns NULL if the index exceeds the length + of the array. + + _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 2); + 2 + > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2); + "b" + """, + since = "2.4.0") +case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType) + + override def dataType: DataType = left.dataType match { + case ArrayType(elementType, _) => elementType + case MapType(_, valueType, _) => valueType + } + + override def inputTypes: Seq[AbstractDataType] = { + Seq(TypeCollection(ArrayType, MapType), + left.dataType match { + case _: ArrayType => IntegerType + case _: MapType => left.dataType.asInstanceOf[MapType].keyType + case _ => AnyDataType // no match for a wrong 'left' expression type + } + ) + } + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] => + TypeUtils.checkForOrderingExpr( + left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName") + case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess + } + } + + override def nullable: Boolean = true + + override def nullSafeEval(value: Any, ordinal: Any): Any = { + left.dataType match { + case _: ArrayType => + val array = value.asInstanceOf[ArrayData] + val index = ordinal.asInstanceOf[Int] + if (array.numElements() < math.abs(index)) { + null + } else { + val idx = if (index == 0) { + throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") + } else if (index > 0) { + index - 1 + } else { + array.numElements() + index + } + if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { + null + } else { + array.get(idx, dataType) + } + } + case _: MapType => + getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + left.dataType match { + case _: ArrayType => + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("elementAtIndex") + val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($eval1.isNullAt($index)) { + | ${ev.isNull} = true; + |} else + """.stripMargin + } else { + "" + } + s""" + |int $index = (int) $eval2; + |if ($eval1.numElements() < Math.abs($index)) { + | ${ev.isNull} = true; + |} else { + | if ($index == 0) { + | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1"); + | } else if ($index > 0) { + | $index--; + | } else { + | $index += $eval1.numElements(); + | } + | $nullCheck + | { + | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + | } + |} + """.stripMargin + }) + case _: MapType => + doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType]) + } + } + + override def prettyName: String = "element_at" +} + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ + Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess + } else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been ${StringType.simpleString}," + + s" ${BinaryType.simpleString} or ${ArrayType.simpleString}, but it's " + + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") + } + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { + case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) + case StringType => + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) + case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { + null + } else { + val arrayData = inputs.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numberOfElements > MAX_ARRAY_LENGTH) { throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") } - val finalData = new Array[AnyRef](numberOfElements.toInt) - var position = 0 - for(ad <- arrayData) { - val arr = ad.toObjectArray(elementType) - Array.copy(arr, 0, finalData, position, arr.length) - position += arr.length + val finalData = new Array[AnyRef](numberOfElements.toInt) + var position = 0 + for(ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, finalData, position, arr.length) + position += arr.length + } + new GenericArrayData(finalData) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val args = ctx.freshName("args") + + val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" + ${eval.code} + if (!${eval.isNull}) { + $args[$index] = ${eval.value}; + } + """ + } + + val (concatenator, initCode) = dataType match { + case BinaryType => + (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => + ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => + val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrays(ctx, elementType) + } else { + genCodeForNonPrimitiveArrays(ctx, elementType) + } + (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") + } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"$javaType[]", args) :: Nil) + ev.copy(code""" + $initCode + $codes + $javaType ${ev.value} = $concatenator.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; + """) + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { + val numElements = ctx.freshName("numElements") + val code = s""" + |long $numElements = 0L; + |for (int z = 0; z < ${children.length}; z++) { + | $numElements += args[z].numElements(); + |} + |if ($numElements > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + + | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + + (code, numElements) + } + + private def nullArgumentProtection() : String = { + if (nullable) { + s""" + |for (int z = 0; z < ${children.length}; z++) { + | if (args[z] == null) return null; + |} + """.stripMargin + } else { + "" + } + } + + private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val counter = ctx.freshName("counter") + val arrayData = ctx.freshName("arrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + s""" + |new Object() { + | public ArrayData concat($javaType[] args) { + | ${nullArgumentProtection()} + | $numElemCode + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | if (args[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + | } else { + | $arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} + | ); + | } + | $counter++; + | } + | } + | return $arrayData; + | } + |}""".stripMargin.stripPrefix("\n") + } + + private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayData = ctx.freshName("arrayObjects") + val counter = ctx.freshName("counter") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + + s""" + |new Object() { + | public ArrayData concat($javaType[] args) { + | ${nullArgumentProtection()} + | $numElemCode + | Object[] $arrayData = new Object[(int)$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + | } + |}""".stripMargin.stripPrefix("\n") + } + + override def toString: String = s"concat(${children.mkString(", ")})" + + override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" +} + +/** + * Transforms an array of arrays into a single array. + */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.", + examples = """ + Examples: + > SELECT _FUNC_(array(array(1, 2), array(3, 4)); + [1,2,3,4] + """, + since = "2.4.0") +case class Flatten(child: Expression) extends UnaryExpression { + + private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] + + override def nullable: Boolean = child.nullable || childDataType.containsNull + + override def dataType: DataType = childDataType.elementType + + lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(_: ArrayType, _) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + s"The argument should be an array of arrays, " + + s"but '${child.sql}' is of ${child.dataType.simpleString} type." + ) + } + + override def nullSafeEval(child: Any): Any = { + val elements = child.asInstanceOf[ArrayData].toObjectArray(dataType) + + if (elements.contains(null)) { + null + } else { + val arrayData = elements.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) + if (numberOfElements > MAX_ARRAY_LENGTH) { + throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + } + val flattenedData = new Array(numberOfElements.toInt) + var position = 0 + for (ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, flattenedData, position, arr.length) + position += arr.length + } + new GenericArrayData(flattenedData) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val code = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value) + } else { + genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) + } + ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code) + }) + } + + private def genCodeForNumberOfElements( + ctx: CodegenContext, + childVariableName: String) : (String, String) = { + val variableName = ctx.freshName("numElements") + val code = s""" + |long $variableName = 0; + |for (int z = 0; z < $childVariableName.numElements(); z++) { + | $variableName += $childVariableName.getArray(z).numElements(); + |} + |if ($variableName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + | $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + (code, variableName) + } + + private def genCodeForFlattenOfPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val counter = ctx.freshName("counter") + val tempArrayDataName = ctx.freshName("tempArrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + s""" + |$numElemCode + |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")} + |int $counter = 0; + |for (int k = 0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for (int l = 0; l < arr.numElements(); l++) { + | if (arr.isNullAt(l)) { + | $tempArrayDataName.setNullAt($counter); + | } else { + | $tempArrayDataName.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue("arr", elementType, "l")} + | ); + | } + | $counter++; + | } + |} + |$arrayDataName = $tempArrayDataName; + """.stripMargin + } + + private def genCodeForFlattenOfNonPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayName = ctx.freshName("arrayObject") + val counter = ctx.freshName("counter") + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + s""" + |$numElemCode + |Object[] $arrayName = new Object[(int)$numElemName]; + |int $counter = 0; + |for (int k = 0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for (int l = 0; l < arr.numElements(); l++) { + | $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")}; + | $counter++; + | } + |} + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin + } + + override def prettyName: String = "flatten" +} + +@ExpressionDescription( + usage = """ + _FUNC_(start, stop, step) - Generates an array of elements from start to stop (inclusive), + incrementing by step. The type of the returned elements is the same as the type of argument + expressions. + + Supported types are: byte, short, integer, long, date, timestamp. + + The start and stop expressions must resolve to the same type. + If start and stop expressions resolve to the 'date' or 'timestamp' type + then the step expression must resolve to the 'interval' type, otherwise to the same type + as the start and stop expressions. + """, + arguments = """ + Arguments: + * start - an expression. The start of the range. + * stop - an expression. The end the range (inclusive). + * step - an optional expression. The step of the range. + By default step is 1 if start is less than or equal to stop, otherwise -1. + For the temporal sequences it's 1 day and -1 day respectively. + If start is greater than stop then the step must be negative, and vice versa. + """, + examples = """ + Examples: + > SELECT _FUNC_(1, 5); + [1, 2, 3, 4, 5] + > SELECT _FUNC_(5, 1); + [5, 4, 3, 2, 1] + > SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval 1 month); + [2018-01-01, 2018-02-01, 2018-03-01] + """, + since = "2.4.0" +) +case class Sequence( + start: Expression, + stop: Expression, + stepOpt: Option[Expression], + timeZoneId: Option[String] = None) + extends Expression + with TimeZoneAwareExpression { + + import Sequence._ + + def this(start: Expression, stop: Expression) = + this(start, stop, None, None) + + def this(start: Expression, stop: Expression, step: Expression) = + this(start, stop, Some(step), None) + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Some(timeZoneId)) + + override def children: Seq[Expression] = Seq(start, stop) ++ stepOpt + + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = children.exists(_.nullable) + + override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false) + + override def checkInputDataTypes(): TypeCheckResult = { + val startType = start.dataType + def stepType = stepOpt.get.dataType + val typesCorrect = + startType.sameType(stop.dataType) && + (startType match { + case TimestampType | DateType => + stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) + case _: IntegralType => + stepOpt.isEmpty || stepType.sameType(startType) + case _ => false + }) + + if (typesCorrect) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"$prettyName only supports integral, timestamp or date types") + } + } + + def coercibleChildren: Seq[Expression] = children.filter(_.dataType != CalendarIntervalType) + + def castChildrenTo(widerType: DataType): Expression = Sequence( + Cast(start, widerType), + Cast(stop, widerType), + stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step), + timeZoneId) + + private lazy val impl: SequenceImpl = dataType.elementType match { + case iType: IntegralType => + type T = iType.InternalType + val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe)) + new IntegralSequenceImpl(iType)(ct, iType.integral) + + case TimestampType => + new TemporalSequenceImpl[Long](LongType, 1, identity, timeZone) + + case DateType => + new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, timeZone) + } + + override def eval(input: InternalRow): Any = { + val startVal = start.eval(input) + if (startVal == null) return null + val stopVal = stop.eval(input) + if (stopVal == null) return null + val stepVal = stepOpt.map(_.eval(input)).getOrElse(impl.defaultStep(startVal, stopVal)) + if (stepVal == null) return null + + ArrayData.toArrayData(impl.eval(startVal, stopVal, stepVal)) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val startGen = start.genCode(ctx) + val stopGen = stop.genCode(ctx) + val stepGen = stepOpt.map(_.genCode(ctx)).getOrElse( + impl.defaultStep.genCode(ctx, startGen, stopGen)) + + val resultType = CodeGenerator.javaType(dataType) + val resultCode = { + val arr = ctx.freshName("arr") + val arrElemType = CodeGenerator.javaType(dataType.elementType) + s""" + |final $arrElemType[] $arr = null; + |${impl.genCode(ctx, startGen.value, stopGen.value, stepGen.value, arr, arrElemType)} + |${ev.value} = UnsafeArrayData.fromPrimitiveArray($arr); + """.stripMargin + } + + if (nullable) { + val nullSafeEval = + startGen.code + ctx.nullSafeExec(start.nullable, startGen.isNull) { + stopGen.code + ctx.nullSafeExec(stop.nullable, stopGen.isNull) { + stepGen.code + ctx.nullSafeExec(stepOpt.exists(_.nullable), stepGen.isNull) { + s""" + |${ev.isNull} = false; + |$resultCode + """.stripMargin + } + } + } + ev.copy(code = + code""" + |boolean ${ev.isNull} = true; + |$resultType ${ev.value} = null; + |$nullSafeEval + """.stripMargin) + + } else { + ev.copy(code = + code""" + |${startGen.code} + |${stopGen.code} + |${stepGen.code} + |$resultType ${ev.value} = null; + |$resultCode + """.stripMargin, + isNull = FalseLiteral) + } + } +} + +object Sequence { + + private type LessThanOrEqualFn = (Any, Any) => Boolean + + private class DefaultStep(lteq: LessThanOrEqualFn, stepType: DataType, one: Any) { + private val negativeOne = UnaryMinus(Literal(one)).eval() + + def apply(start: Any, stop: Any): Any = { + if (lteq(start, stop)) one else negativeOne + } + + def genCode(ctx: CodegenContext, startGen: ExprCode, stopGen: ExprCode): ExprCode = { + val Seq(oneVal, negativeOneVal) = Seq(one, negativeOne).map(Literal(_).genCode(ctx).value) + ExprCode.forNonNullValue(JavaCode.expression( + s"${startGen.value} <= ${stopGen.value} ? $oneVal : $negativeOneVal", + stepType)) + } + } + + private trait SequenceImpl { + def eval(start: Any, stop: Any, step: Any): Any + + def genCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + arr: String, + elemType: String): String + + val defaultStep: DefaultStep + } + + private class IntegralSequenceImpl[T: ClassTag] + (elemType: IntegralType)(implicit num: Integral[T]) extends SequenceImpl { + + override val defaultStep: DefaultStep = new DefaultStep( + (elemType.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], + elemType, + num.one) + + override def eval(input1: Any, input2: Any, input3: Any): Array[T] = { + import num._ + + val start = input1.asInstanceOf[T] + val stop = input2.asInstanceOf[T] + val step = input3.asInstanceOf[T] + + var i: Int = getSequenceLength(start, stop, step) + val arr = new Array[T](i) + while (i > 0) { + i -= 1 + arr(i) = start + step * num.fromInt(i) + } + arr + } + + override def genCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + arr: String, + elemType: String): String = { + val i = ctx.freshName("i") + s""" + |${genSequenceLengthCode(ctx, start, stop, step, i)} + |$arr = new $elemType[$i]; + |while ($i > 0) { + | $i--; + | $arr[$i] = ($elemType) ($start + $step * $i); + |} + """.stripMargin + } + } + + private class TemporalSequenceImpl[T: ClassTag] + (dt: IntegralType, scale: Long, fromLong: Long => T, timeZone: TimeZone) + (implicit num: Integral[T]) extends SequenceImpl { + + override val defaultStep: DefaultStep = new DefaultStep( + (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], + CalendarIntervalType, + new CalendarInterval(0, MICROS_PER_DAY)) + + private val backedSequenceImpl = new IntegralSequenceImpl[T](dt) + private val microsPerMonth = 28 * CalendarInterval.MICROS_PER_DAY + + override def eval(input1: Any, input2: Any, input3: Any): Array[T] = { + val start = input1.asInstanceOf[T] + val stop = input2.asInstanceOf[T] + val step = input3.asInstanceOf[CalendarInterval] + val stepMonths = step.months + val stepMicros = step.microseconds + + if (stepMonths == 0) { + backedSequenceImpl.eval(start, stop, fromLong(stepMicros / scale)) + + } else { + // To estimate the resulted array length we need to make assumptions + // about a month length in microseconds + val intervalStepInMicros = stepMicros + stepMonths * microsPerMonth + val startMicros: Long = num.toLong(start) * scale + val stopMicros: Long = num.toLong(stop) * scale + val maxEstimatedArrayLength = + getSequenceLength(startMicros, stopMicros, intervalStepInMicros) + + val stepSign = if (stopMicros > startMicros) +1 else -1 + val exclusiveItem = stopMicros + stepSign + val arr = new Array[T](maxEstimatedArrayLength) + var t = startMicros + var i = 0 + + while (t < exclusiveItem ^ stepSign < 0) { + arr(i) = fromLong(t / scale) + t = timestampAddInterval(t, stepMonths, stepMicros, timeZone) + i += 1 } - new GenericArrayData(finalData) + + // truncate array to the correct length + if (arr.length == i) arr else arr.slice(0, i) + } + } + + override def genCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + arr: String, + elemType: String): String = { + val stepMonths = ctx.freshName("stepMonths") + val stepMicros = ctx.freshName("stepMicros") + val stepScaled = ctx.freshName("stepScaled") + val intervalInMicros = ctx.freshName("intervalInMicros") + val startMicros = ctx.freshName("startMicros") + val stopMicros = ctx.freshName("stopMicros") + val arrLength = ctx.freshName("arrLength") + val stepSign = ctx.freshName("stepSign") + val exclusiveItem = ctx.freshName("exclusiveItem") + val t = ctx.freshName("t") + val i = ctx.freshName("i") + val genTimeZone = ctx.addReferenceObj("timeZone", timeZone, classOf[TimeZone].getName) + + val sequenceLengthCode = + s""" + |final long $intervalInMicros = $stepMicros + $stepMonths * ${microsPerMonth}L; + |${genSequenceLengthCode(ctx, startMicros, stopMicros, intervalInMicros, arrLength)} + """.stripMargin + + val timestampAddIntervalCode = + s""" + |$t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval( + | $t, $stepMonths, $stepMicros, $genTimeZone); + """.stripMargin + + s""" + |final int $stepMonths = $step.months; + |final long $stepMicros = $step.microseconds; + | + |if ($stepMonths == 0) { + | final $elemType $stepScaled = ($elemType) ($stepMicros / ${scale}L); + | ${backedSequenceImpl.genCode(ctx, start, stop, stepScaled, arr, elemType)}; + | + |} else { + | final long $startMicros = $start * ${scale}L; + | final long $stopMicros = $stop * ${scale}L; + | + | $sequenceLengthCode + | + | final int $stepSign = $stopMicros > $startMicros ? +1 : -1; + | final long $exclusiveItem = $stopMicros + $stepSign; + | + | $arr = new $elemType[$arrLength]; + | long $t = $startMicros; + | int $i = 0; + | + | while ($t < $exclusiveItem ^ $stepSign < 0) { + | $arr[$i] = ($elemType) ($t / ${scale}L); + | $timestampAddIntervalCode + | $i += 1; + | } + | + | if ($arr.length > $i) { + | $arr = java.util.Arrays.copyOf($arr, $i); + | } + |} + """.stripMargin + } + } + + private def getSequenceLength[U](start: U, stop: U, step: U)(implicit num: Integral[U]): Int = { + import num._ + require( + (step > num.zero && start <= stop) + || (step < num.zero && start >= stop) + || (step == num.zero && start == stop), + s"Illegal sequence boundaries: $start to $stop by $step") + + val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / step.toLong + + require( + len <= MAX_ROUNDED_ARRAY_LENGTH, + s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH") + + len.toInt + } + + private def genSequenceLengthCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + len: String): String = { + val longLen = ctx.freshName("longLen") + s""" + |if (!(($step > 0 && $start <= $stop) || + | ($step < 0 && $start >= $stop) || + | ($step == 0 && $start == $stop))) { + | throw new IllegalArgumentException( + | "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step); + |} + |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $step; + |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) { + | throw new IllegalArgumentException( + | "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH"); + |} + |int $len = (int) $longLen; + """.stripMargin + } +} + +/** + * Returns the array containing the given input value (left) count (right) times. + */ +@ExpressionDescription( + usage = "_FUNC_(element, count) - Returns the array containing element count times.", + examples = """ + Examples: + > SELECT _FUNC_('123', 2); + ['123', '123'] + """, + since = "2.4.0") +case class ArrayRepeat(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + override def dataType: ArrayType = ArrayType(left.dataType, left.nullable) + + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType) + + override def nullable: Boolean = right.nullable + + override def eval(input: InternalRow): Any = { + val count = right.eval(input) + if (count == null) { + null + } else { + if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to create array with $count elements " + + s"due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + } + val element = left.eval(input) + new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) + } + } + + override def prettyName: String = "array_repeat" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) + val element = leftGen.value + val count = rightGen.value + val et = dataType.elementType + + val coreLogic = if (CodeGenerator.isPrimitiveType(et)) { + genCodeForPrimitiveElement(ctx, et, element, count, leftGen.isNull, ev.value) + } else { + genCodeForNonPrimitiveElement(ctx, element, count, leftGen.isNull, ev.value) + } + val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic) + + ev.copy(code = + code""" + |boolean ${ev.isNull} = false; + |${leftGen.code} + |${rightGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = + | ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """.stripMargin) + } + + private def nullElementsProtection( + ev: ExprCode, + rightIsNull: String, + coreLogic: String): String = { + if (nullable) { + s""" + |if ($rightIsNull) { + | ${ev.isNull} = true; + |} else { + | ${coreLogic} + |} + """.stripMargin + } else { + coreLogic + } + } + + private def genCodeForNumberOfElements(ctx: CodegenContext, count: String): (String, String) = { + val numElements = ctx.freshName("numElements") + val numElementsCode = + s""" + |int $numElements = 0; + |if ($count > 0) { + | $numElements = $count; + |} + |if ($numElements > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + + | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + + (numElements, numElementsCode) + } + + private def genCodeForPrimitiveElement( + ctx: CodegenContext, + elementType: DataType, + element: String, + count: String, + leftIsNull: String, + arrayDataName: String): String = { + val tempArrayDataName = ctx.freshName("tempArrayData") + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val errorMessage = s" $prettyName failed." + val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) + + s""" + |$numElemCode + |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)} + |if (!$leftIsNull) { + | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { + | $tempArrayDataName.set$primitiveValueTypeName(k, $element); + | } + |} else { + | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { + | $tempArrayDataName.setNullAt(k); + | } + |} + |$arrayDataName = $tempArrayDataName; + """.stripMargin + } + + private def genCodeForNonPrimitiveElement( + ctx: CodegenContext, + element: String, + count: String, + leftIsNull: String, + arrayDataName: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayName = ctx.freshName("arrayObject") + val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) + + s""" + |$numElemCode + |Object[] $arrayName = new Object[(int)$numElemName]; + |if (!$leftIsNull) { + | for (int k = 0; k < $numElemName; k++) { + | $arrayName[k] = $element; + | } + |} + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin + } + +} + +/** + * Remove all elements that equal to element from the given array + */ +@ExpressionDescription( + usage = "_FUNC_(array, element) - Remove all elements that equal to element from array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, null, 3), 3); + [1,2,null] + """, since = "2.4.0") +case class ArrayRemove(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = left.dataType + + override def inputTypes: Seq[AbstractDataType] = { + val elementType = left.dataType match { + case t: ArrayType => t.elementType + case _ => AnyDataType + } + Seq(ArrayType, elementType) + } + + lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + } + } + + override def nullSafeEval(arr: Any, value: Any): Any = { + val newArray = new Array[Any](arr.asInstanceOf[ArrayData].numElements()) + var pos = 0 + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == null || !ordering.equiv(v, value)) { + newArray(pos) = v + pos += 1 } + ) + new GenericArrayData(newArray.slice(0, pos)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (arr, value) => { + val numsToRemove = ctx.freshName("numsToRemove") + val newArraySize = ctx.freshName("newArraySize") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(arr, elementType, i) + val isEqual = ctx.genEqual(elementType, value, getValue) + s""" + |int $numsToRemove = 0; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if (!$arr.isNullAt($i) && $isEqual) { + | $numsToRemove = $numsToRemove + 1; + | } + |} + |int $newArraySize = $arr.numElements() - $numsToRemove; + |${genCodeForResult(ctx, ev, arr, value, newArraySize)} + """.stripMargin + }) } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val evals = children.map(_.genCode(ctx)) - val args = ctx.freshName("args") - - val inputs = evals.zipWithIndex.map { case (eval, index) => + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + value: String, + newArraySize: String): String = { + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val getValue = CodeGenerator.getValue(inputArray, elementType, i) + val isEqual = ctx.genEqual(elementType, value, getValue) + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName s""" - ${eval.code} - if (!${eval.isNull}) { - $args[$index] = ${eval.value}; - } - """ + |int $pos = 0; + |Object[] $values = new Object[$newArraySize]; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $values[$pos] = null; + | $pos = $pos + 1; + | } + | else { + | if (!($isEqual)) { + | $values[$pos] = $getValue; + | $pos = $pos + 1; + | } + | } + |} + |${ev.value} = new $arrayClass($values); + """.stripMargin + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |${ctx.createUnsafeArray(values, newArraySize, elementType, s" $prettyName failed.")} + |int $pos = 0; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $values.setNullAt($pos); + | $pos = $pos + 1; + | } + | else { + | if (!($isEqual)) { + | $values.set$primitiveValueTypeName($pos, $getValue); + | $pos = $pos + 1; + | } + | } + |} + |${ev.value} = $values; + """.stripMargin } + } - val (concatenator, initCode) = dataType match { - case BinaryType => - (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") - case StringType => - ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") - case ArrayType(elementType, _) => - val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { - genCodeForPrimitiveArrays(ctx, elementType) + override def prettyName: String = "array_remove" +} + +/** + * Removes duplicate values from the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Removes duplicate values from the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, null, 3)); + [1,2,3,null] + """, since = "2.4.0") +case class ArrayDistinct(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def dataType: DataType = child.dataType + + @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") + } + } + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + override def nullSafeEval(array: Any): Any = { + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + if (elementTypeSupportEquals) { + new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) + } else { + var foundNullElement = false + var pos = 0 + for (i <- 0 until data.length) { + if (data(i) == null) { + if (!foundNullElement) { + foundNullElement = true + pos = pos + 1 + } } else { - genCodeForNonPrimitiveArrays(ctx, elementType) + var j = 0 + var done = false + while (j <= i && !done) { + if (data(j) != null && ordering.equiv(data(j), data(i))) { + done = true + } + j = j + 1 + } + if (i == j - 1) { + pos = pos + 1 + } } - (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") + } + new GenericArrayData(data.slice(0, pos)) } - val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = inputs, - funcName = "valueConcat", - extraArguments = (s"$javaType[]", args) :: Nil) - ev.copy(s""" - $initCode - $codes - $javaType ${ev.value} = $concatenator.concat($args); - boolean ${ev.isNull} = ${ev.value} == null; - """) } - private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { - val numElements = ctx.freshName("numElements") - val code = s""" - |long $numElements = 0L; - |for (int z = 0; z < ${children.length}; z++) { - | $numElements += args[z].numElements(); - |} - |if ($numElements > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + - | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); - |} - """.stripMargin + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (array) => { + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") + val getValue1 = CodeGenerator.getValue(array, elementType, i) + val getValue2 = CodeGenerator.getValue(array, elementType, j) + val foundNullElement = ctx.freshName("foundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val hs = ctx.freshName("hs") + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" + if (elementTypeSupportEquals) { + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | $foundNullElement = true; + | } else { + | $hs.add($getValue1); + | } + |} + |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0); + |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + """.stripMargin + } else { + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | if (!($foundNullElement)) { + | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; + | $foundNullElement = true; + | } + | } else { + | int $j; + | for ($j = 0; $j < $i; $j ++) { + | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) { + | break; + | } + | } + | if ($i == $j) { + | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; + | } + | } + |} + | + |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + """.stripMargin + } + }) + } - (code, numElements) + private def setNull( + isPrimitive: Boolean, + foundNullElement: String, + distinctArray: String, + pos: String): String = { + val setNullValue = + if (!isPrimitive) { + s"$distinctArray[$pos] = null"; + } else { + s"$distinctArray.setNullAt($pos)"; + } + + s""" + |if (!($foundNullElement)) { + | $setNullValue; + | $pos = $pos + 1; + | $foundNullElement = true; + |} + """.stripMargin } - private def nullArgumentProtection() : String = { - if (nullable) { + private def setNotNullValue(isPrimitive: Boolean, + distinctArray: String, + pos: String, + getValue1: String, + primitiveValueTypeName: String): String = { + if (!isPrimitive) { + s"$distinctArray[$pos] = $getValue1"; + } else { + s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)"; + } + } + + private def setValueForFastEval( + isPrimitive: Boolean, + hs: String, + distinctArray: String, + pos: String, + getValue1: String, + primitiveValueTypeName: String): String = { + val setValue = setNotNullValue(isPrimitive, + distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |if (!($hs.contains($getValue1))) { + | $hs.add($getValue1); + | $setValue; + | $pos = $pos + 1; + |} + """.stripMargin + } + + private def setValueForBruteForceEval( + isPrimitive: Boolean, + i: String, + j: String, + inputArray: String, + distinctArray: String, + pos: String, + getValue1: String, + isEqual: String, + primitiveValueTypeName: String): String = { + val setValue = setNotNullValue(isPrimitive, + distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |int $j; + |for ($j = 0; $j < $i; $j ++) { + | if (!$inputArray.isNullAt($j) && $isEqual) { + | break; + | } + |} + |if ($i == $j) { + | $setValue; + | $pos = $pos + 1; + |} + """.stripMargin + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + size: String): String = { + val distinctArray = ctx.freshName("distinctArray") + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val pos = ctx.freshName("pos") + val getValue1 = CodeGenerator.getValue(inputArray, elementType, i) + val getValue2 = CodeGenerator.getValue(inputArray, elementType, j) + val isEqual = ctx.genEqual(elementType, getValue1, getValue2) + val foundNullElement = ctx.freshName("foundNullElement") + val hs = ctx.freshName("hs") + val openHashSet = classOf[OpenHashSet[_]].getName + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" + val setNullForNonPrimitive = + setNull(false, foundNullElement, distinctArray, pos) + if (elementTypeSupportEquals) { + val setValueForFast = setValueForFastEval(false, hs, distinctArray, pos, getValue1, "") + s""" + |int $pos = 0; + |Object[] $distinctArray = new Object[$size]; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForNonPrimitive; + | } else { + | $setValueForFast; + | } + |} + |${ev.value} = new $arrayClass($distinctArray); + """.stripMargin + } else { + val setValueForBruteForce = setValueForBruteForceEval( + false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "") + s""" + |int $pos = 0; + |Object[] $distinctArray = new Object[$size]; + |boolean $foundNullElement = false; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForNonPrimitive; + | } else { + | $setValueForBruteForce; + | } + |} + |${ev.value} = new $arrayClass($distinctArray); + """.stripMargin + } + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos) + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()" + val setValueForFast = + setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName) s""" - |for (int z = 0; z < ${children.length}; z++) { - | if (args[z] == null) return null; + |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} + |int $pos = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForPrimitive; + | } else { + | $setValueForFast; + | } |} - """.stripMargin + |${ev.value} = $distinctArray; + """.stripMargin + } + } + + override def prettyName: String = "array_distinct" +} + +/** + * Will become common base class for [[ArrayUnion]], ArrayIntersect, and ArrayExcept. + */ +abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { + override def dataType: DataType = { + val dataTypes = children.map(_.dataType.asInstanceOf[ArrayType]) + ArrayType(elementType, dataTypes.exists(_.containsNull)) + } + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType, + s"function $prettyName") } else { - "" + typeCheckResult } } - private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { - val arrayName = ctx.freshName("array") - val arraySizeName = ctx.freshName("size") - val counter = ctx.freshName("counter") - val arrayData = ctx.freshName("arrayData") + @transient protected lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) - val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + @transient protected lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } +} - val unsafeArraySizeInBytes = s""" - |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElemName, - | ${elementType.defaultSize}); - |if ($arraySizeName > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName + - | " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" + - | " for UnsafeArrayData."); - |} - """.stripMargin - val baseOffset = Platform.BYTE_ARRAY_OFFSET - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) +object ArraySetLike { + def throwUnionLengthOverflowException(length: Int): Unit = { + throw new RuntimeException(s"Unsuccessful try to union arrays with $length " + + s"elements due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } +} - s""" - |new Object() { - | public ArrayData concat($javaType[] args) { - | ${nullArgumentProtection()} - | $numElemCode - | $unsafeArraySizeInBytes - | byte[] $arrayName = new byte[(int)$arraySizeName]; - | UnsafeArrayData $arrayData = new UnsafeArrayData(); - | Platform.putLong($arrayName, $baseOffset, $numElemName); - | $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName); - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | if (args[y].isNullAt(z)) { - | $arrayData.setNullAt($counter); - | } else { - | $arrayData.set$primitiveValueTypeName( - | $counter, - | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} - | ); - | } - | $counter++; - | } - | } - | return $arrayData; - | } - |}""".stripMargin.stripPrefix("\n") + +/** + * Returns an array of the elements in the union of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, + without duplicates. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 2, 3, 5) + """, + since = "2.4.0") +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { + var hsInt: OpenHashSet[Int] = _ + var hsLong: OpenHashSet[Long] = _ + + def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getInt(idx) + if (!hsInt.contains(elem)) { + if (resultArray != null) { + resultArray.setInt(pos, elem) + } + hsInt.add(elem) + true + } else { + false + } } - private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { - val genericArrayClass = classOf[GenericArrayData].getName - val arrayData = ctx.freshName("arrayObjects") - val counter = ctx.freshName("counter") + def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getLong(idx) + if (!hsLong.contains(elem)) { + if (resultArray != null) { + resultArray.setLong(pos, elem) + } + hsLong.add(elem) + true + } else { + false + } + } - val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + def evalIntLongPrimitiveType( + array1: ArrayData, + array2: ArrayData, + resultArray: ArrayData, + isLongType: Boolean): Int = { + // store elements into resultArray + var nullElementSize = 0 + var pos = 0 + Seq(array1, array2).foreach { array => + var i = 0 + while (i < array.numElements()) { + val size = if (!isLongType) hsInt.size else hsLong.size + if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(size) + } + if (array.isNullAt(i)) { + if (nullElementSize == 0) { + if (resultArray != null) { + resultArray.setNullAt(pos) + } + pos += 1 + nullElementSize = 1 + } + } else { + val assigned = if (!isLongType) { + assignInt(array, i, resultArray, pos) + } else { + assignLong(array, i, resultArray, pos) + } + if (assigned) { + pos += 1 + } + } + i += 1 + } + } + pos + } - s""" - |new Object() { - | public ArrayData concat($javaType[] args) { - | ${nullArgumentProtection()} - | $numElemCode - | Object[] $arrayData = new Object[(int)$numElemName]; - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; - | $counter++; - | } - | } - | return new $genericArrayClass($arrayData); - | } - |}""".stripMargin.stripPrefix("\n") + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + if (elementTypeSupportEquals) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + // calculate result array size + hsInt = new OpenHashSet[Int] + val elements = evalIntLongPrimitiveType(array1, array2, null, false) + hsInt = new OpenHashSet[Int] + val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( + IntegerType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) + } else { + UnsafeArrayData.forPrimitiveArray( + Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize) + } + evalIntLongPrimitiveType(array1, array2, resultArray, false) + resultArray + case LongType => + // avoid boxing of primitive long array elements + // calculate result array size + hsLong = new OpenHashSet[Long] + val elements = evalIntLongPrimitiveType(array1, array2, null, true) + hsLong = new OpenHashSet[Long] + val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( + LongType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) + } else { + UnsafeArrayData.forPrimitiveArray( + Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize) + } + evalIntLongPrimitiveType(array1, array2, resultArray, true) + resultArray + case _ => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new OpenHashSet[Any] + var foundNullElement = false + Seq(array1, array2).foreach { array => + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + if (!foundNullElement) { + arrayBuffer += null + foundNullElement = true + } + } else { + val elem = array.get(i, elementType) + if (!hs.contains(elem)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += elem + hs.add(elem) + } + } + i += 1 + } + } + new GenericArrayData(arrayBuffer) + } + } else { + ArrayUnion.unionOrdering(array1, array2, elementType, ordering) + } } - override def toString: String = s"concat(${children.mkString(", ")})" + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) = + if (elementTypeSupportEquals) { + elementType match { + case ByteType | ShortType | IntegerType | LongType => + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", + if (elementType == LongType) "Long" else "Int", + s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + if (elementType == LongType) "(long)" else "(int)", + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case _ => + val genericArrayData = classOf[GenericArrayData].getName + val et = ctx.addReferenceObj("elementType", elementType) + ("", "Object", + s"get($i, $et)", s"update($pos, $value)", "Object", "", + s"${ev.value} = new $genericArrayData(new Object[$size]);") + } + } else { + ("", "", "", "", "", "", "") + } - override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" + nullSafeCodeGen(ctx, ev, (array1, array2) => { + if (openHashElementType != "") { + // Here, we ensure elementTypeSupportEquals is true + val foundNullElement = ctx.freshName("foundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" + val hs = ctx.freshName("hs") + val arrayData = classOf[ArrayData].getName + val arrays = ctx.freshName("arrays") + val array = ctx.freshName("array") + val arrayDataIdx = ctx.freshName("arrayDataIdx") + s""" + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |boolean $foundNullElement = false; + |$arrayData[] $arrays = new $arrayData[]{$array1, $array2}; + |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { + | $arrayData $array = $arrays[$arrayDataIdx]; + | for (int $i = 0; $i < $array.numElements(); $i++) { + | if ($array.isNullAt($i)) { + | $foundNullElement = true; + | } else { + | $hs.add$postFix($array.$getter); + | } + | } + |} + |int $size = $hs.size() + ($foundNullElement ? 1 : 0); + |$arrayBuilder + |$hs = new $openHashSet$postFix($classTag); + |$foundNullElement = false; + |int $pos = 0; + |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { + | $arrayData $array = $arrays[$arrayDataIdx]; + | for (int $i = 0; $i < $array.numElements(); $i++) { + | if ($array.isNullAt($i)) { + | if (!$foundNullElement) { + | ${ev.value}.setNullAt($pos++); + | $foundNullElement = true; + | } + | } else { + | $javaTypeName $value = $array.$getter; + | if (!$hs.contains($castOp $value)) { + | $hs.add$postFix($value); + | ${ev.value}.$setter; + | $pos++; + | } + | } + | } + |} + """.stripMargin + } else { + val arrayUnion = classOf[ArrayUnion].getName + val et = ctx.addReferenceObj("elementTypeUnion", elementType) + val order = ctx.addReferenceObj("orderingUnion", ordering) + val method = "unionOrdering" + s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);" + } + }) + } + + override def prettyName: String = "array_union" +} + +object ArrayUnion { + def unionOrdering( + array1: ArrayData, + array2: ArrayData, + elementType: DataType, + ordering: Ordering[Any]): ArrayData = { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { + var found = false + if (elem == null) { + if (alreadyIncludeNull) { + found = true + } else { + alreadyIncludeNull = true + } + } else { + // check elem is already stored in arrayBuffer or not? + var j = 0 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { + found = true + } + j = j + 1 + } + } + if (!found) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem + } + })) + new GenericArrayData(arrayBuffer) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67876a8565488..0a5f8a907b50a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -21,7 +21,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods @@ -63,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( - code = preprocess + assigns + postprocess, + code = code"${preprocess}${assigns}${postprocess}", value = JavaCode.variable(arrayData, dataType), isNull = FalseLiteral) } @@ -219,7 +220,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) = GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false) val code = - s""" + code""" final boolean ${ev.isNull} = false; $preprocessKeyData $assignKeys @@ -235,6 +236,76 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def prettyName: String = "map" } +/** + * Returns a catalyst Map containing the two arrays in children expressions as keys and values. + */ +@ExpressionDescription( + usage = """ + _FUNC_(keys, values) - Creates a map with a pair of the given key/value arrays. All elements + in keys should not be null""", + examples = """ + Examples: + > SELECT _FUNC_([1.0, 3.0], ['2', '4']); + {1.0:"2",3.0:"4"} + """, since = "2.4.0") +case class MapFromArrays(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) + + override def dataType: DataType = { + MapType( + keyType = left.dataType.asInstanceOf[ArrayType].elementType, + valueType = right.dataType.asInstanceOf[ArrayType].elementType, + valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull) + } + + override def nullSafeEval(keyArray: Any, valueArray: Any): Any = { + val keyArrayData = keyArray.asInstanceOf[ArrayData] + val valueArrayData = valueArray.asInstanceOf[ArrayData] + if (keyArrayData.numElements != valueArrayData.numElements) { + throw new RuntimeException("The given two arrays should have the same length") + } + val leftArrayType = left.dataType.asInstanceOf[ArrayType] + if (leftArrayType.containsNull) { + var i = 0 + while (i < keyArrayData.numElements) { + if (keyArrayData.isNullAt(i)) { + throw new RuntimeException("Cannot use null as map key!") + } + i += 1 + } + } + new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy()) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => { + val arrayBasedMapData = classOf[ArrayBasedMapData].getName + val leftArrayType = left.dataType.asInstanceOf[ArrayType] + val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else { + val i = ctx.freshName("i") + s""" + |for (int $i = 0; $i < $keyArrayData.numElements(); $i++) { + | if ($keyArrayData.isNullAt($i)) { + | throw new RuntimeException("Cannot use null as map key!"); + | } + |} + """.stripMargin + } + s""" + |if ($keyArrayData.numElements() != $valueArrayData.numElements()) { + | throw new RuntimeException("The given two arrays should have the same length"); + |} + |$keyArrayElemNullCheck + |${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy()); + """.stripMargin + }) + } + + override def prettyName: String = "map_from_arrays" +} + /** * An expression representing a not yet available attribute name. This expression is unevaluable * and as its name suggests it is a temporary place holder until we're able to determine the @@ -373,7 +444,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc extraArguments = "Object[]" -> values :: Nil) ev.copy(code = - s""" + code""" |Object[] $values = new Object[${valExprs.size}]; |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3fba52d745453..99671d5b863c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -273,7 +273,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. - def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = { + def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = { val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() @@ -282,7 +282,7 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy var i = 0 var found = false while (i < length && !found) { - if (keys.get(i, keyType) == ordinal) { + if (ordering.equiv(keys.get(i, keyType), ordinal)) { found = true } else { i += 1 @@ -345,8 +345,19 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy case class GetMapValue(child: Expression, key: Expression) extends GetMapValueUtil with ExtractValue with NullIntolerant { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(keyType) + private def keyType = child.dataType.asInstanceOf[MapType].keyType + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(keyType, s"function $prettyName") + } + } + // We have done type checking for child in `ExtractValue`, so only need to check the `key`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) @@ -363,7 +374,7 @@ case class GetMapValue(child: Expression, key: Expression) // todo: current search is O(n), improve it. override def nullSafeEval(value: Any, ordinal: Any): Any = { - getValueEval(value, ordinal, keyType) + getValueEval(value, ordinal, keyType, ordering) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 205d77f6a9acf..30ce9e4743da9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ // scalastyle:off line.size.limit @@ -32,7 +33,12 @@ import org.apache.spark.sql.types._ """) // scalastyle:on line.size.limit case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends Expression { + extends ComplexTypeMergingExpression { + + @transient + override lazy val inputTypesForMerging: Seq[DataType] = { + Seq(trueValue.dataType, falseValue.dataType) + } override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil override def nullable: Boolean = trueValue.nullable || falseValue.nullable @@ -42,7 +48,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi TypeCheckResult.TypeCheckFailure( "type of predicate expression in If should be boolean, " + s"not ${predicate.dataType.simpleString}") - } else if (!trueValue.dataType.sameType(falseValue.dataType)) { + } else if (!areInputTypesForMergingEqual) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") } else { @@ -50,8 +56,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } - override def dataType: DataType = trueValue.dataType - override def eval(input: InternalRow): Any = { if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) { trueValue.eval(input) @@ -66,7 +70,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val falseEval = falseValue.genCode(ctx) val code = - s""" + code""" |${condEval.code} |boolean ${ev.isNull} = false; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -117,27 +121,24 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi case class CaseWhen( branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) - extends Expression with Serializable { + extends ComplexTypeMergingExpression with Serializable { override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue // both then and else expressions should be considered. - def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType) - - def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall { - case Seq(dt1, dt2) => dt1.sameType(dt2) + @transient + override lazy val inputTypesForMerging: Seq[DataType] = { + branches.map(_._2.dataType) ++ elseValue.map(_.dataType) } - override def dataType: DataType = branches.head._2.dataType - override def nullable: Boolean = { // Result is nullable if any of the branch is nullable, or if the else value is nullable branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true) } override def checkInputDataTypes(): TypeCheckResult = { - // Make sure all branch conditions are boolean types. - if (valueTypesEqual) { + if (areInputTypesForMergingEqual) { + // Make sure all branch conditions are boolean types. if (branches.forall(_._1.dataType == BooleanType)) { TypeCheckResult.TypeCheckSuccess } else { @@ -265,7 +266,7 @@ case class CaseWhen( }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $codes @@ -293,7 +294,7 @@ object CaseWhen { case cond :: value :: Nil => Some((cond, value)) case value :: Nil => None }.toArray.toSeq // force materialization to make the seq serializable - val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + val elseValue = if (branches.size % 2 != 0) Some(branches.last) else None CaseWhen(cases, elseValue) } } @@ -308,7 +309,7 @@ object CaseKeyWhen { case Seq(cond, value) => Some((EqualTo(key, cond), value)) case Seq(value) => None }.toArray.toSeq // force materialization to make the seq serializable - val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + val elseValue = if (branches.size % 2 != 0) Some(branches.last) else None CaseWhen(cases, elseValue) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index b9b2cd5bdb9f0..08838d2b2c612 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -27,6 +27,7 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -717,7 +718,7 @@ abstract class UnixTime } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -746,7 +747,7 @@ abstract class UnixTime }) case TimestampType => val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -757,7 +758,7 @@ abstract class UnixTime val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -852,7 +853,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1016,6 +1017,48 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S } } +/** + * A special expression used to convert the string input of `to/from_utc_timestamp` to timestamp, + * which requires the timestamp string to not have timezone information, otherwise null is returned. + */ +case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def dataType: DataType = TimestampType + override def nullable: Boolean = true + override def toString: String = child.toString + override def sql: String = child.sql + + override def nullSafeEval(input: Any): Any = { + DateTimeUtils.stringToTimestamp( + input.asInstanceOf[UTF8String], timeZone, rejectTzInString = true).orNull + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val tz = ctx.addReferenceObj("timeZone", timeZone) + val longOpt = ctx.freshName("longOpt") + val eval = child.genCode(ctx) + val code = code""" + |${eval.code} + |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true; + |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; + |if (!${eval.isNull}) { + | scala.Option $longOpt = $dtu.stringToTimestamp(${eval.value}, $tz, true); + | if ($longOpt.isDefined()) { + | ${ev.value} = ((Long) $longOpt.get()).longValue(); + | ${ev.isNull} = false; + | } + |} + """.stripMargin + ev.copy(code = code) + } +} + /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield @@ -1048,7 +1091,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1062,7 +1105,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1152,42 +1195,61 @@ case class AddMonths(startDate: Expression, numMonths: Expression) } /** - * Returns number of months between dates date1 and date2. + * Returns number of months between times `timestamp1` and `timestamp2`. + * If `timestamp1` is later than `timestamp2`, then the result is positive. + * If `timestamp1` and `timestamp2` are on the same day of month, or both + * are the last day of month, time of day will be ignored. Otherwise, the + * difference is calculated based on 31 days per month, and rounded to + * 8 digits unless roundOff=false. */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(timestamp1, timestamp2) - Returns number of months between `timestamp1` and `timestamp2`.", + usage = """ + _FUNC_(timestamp1, timestamp2[, roundOff]) - If `timestamp1` is later than `timestamp2`, then the result + is positive. If `timestamp1` and `timestamp2` are on the same day of month, or both + are the last day of month, time of day will be ignored. Otherwise, the difference is + calculated based on 31 days per month, and rounded to 8 digits unless roundOff=false. + """, examples = """ Examples: > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30'); 3.94959677 + > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30', false); + 3.9495967741935485 """, since = "1.5.0") // scalastyle:on line.size.limit -case class MonthsBetween(date1: Expression, date2: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { +case class MonthsBetween( + date1: Expression, + date2: Expression, + roundOff: Expression, + timeZoneId: Option[String] = None) + extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + + def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None) - def this(date1: Expression, date2: Expression) = this(date1, date2, None) + def this(date1: Expression, date2: Expression, roundOff: Expression) = + this(date1, date2, roundOff, None) - override def left: Expression = date1 - override def right: Expression = date2 + override def children: Seq[Expression] = Seq(date1, date2, roundOff) - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType) + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType, BooleanType) override def dataType: DataType = DoubleType override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - override def nullSafeEval(t1: Any, t2: Any): Any = { - DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long], timeZone) + override def nullSafeEval(t1: Any, t2: Any, roundOff: Any): Any = { + DateTimeUtils.monthsBetween( + t1.asInstanceOf[Long], t2.asInstanceOf[Long], roundOff.asInstanceOf[Boolean], timeZone) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (l, r) => { - s"""$dtu.monthsBetween($l, $r, $tz)""" + defineCodeGen(ctx, ev, (d1, d2, roundOff) => { + s"""$dtu.monthsBetween($d1, $d2, $roundOff, $tz)""" }) } @@ -1226,7 +1288,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1240,7 +1302,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1383,13 +1445,13 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { val javaType = CodeGenerator.javaType(dataType) if (format.foldable) { if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { val t = instant.genCode(ctx) val truncFuncStr = truncFunc(t.value, truncLevel.toString) - ev.copy(code = s""" + ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1471,14 +1533,14 @@ case class TruncDate(date: Expression, format: Expression) """, examples = """ Examples: - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'YEAR'); - 2015-01-01T00:00:00 - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'MM'); - 2015-03-01T00:00:00 - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'DD'); - 2015-03-05T00:00:00 - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'HOUR'); - 2015-03-05T09:00:00 + > SELECT _FUNC_('YEAR', '2015-03-05T09:32:05.359'); + 2015-01-01 00:00:00 + > SELECT _FUNC_('MM', '2015-03-05T09:32:05.359'); + 2015-03-01 00:00:00 + > SELECT _FUNC_('DD', '2015-03-05T09:32:05.359'); + 2015-03-05 00:00:00 + > SELECT _FUNC_('HOUR', '2015-03-05T09:32:05.359'); + 2015-03-05 09:00:00 """, since = "2.3.0") // scalastyle:on line.size.limit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index db1579ba28671..04de83343be71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} import org.apache.spark.sql.types._ /** @@ -72,7 +72,8 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def eval(input: InternalRow): Any = child.eval(input) /** Just a simple pass-through for code generation. */ override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + ev.copy(EmptyBlock) override def prettyName: String = "promote_precision" override def sql: String = child.sql override lazy val canonicalized: Expression = child.canonicalized diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 3af4bfebad45e..b7c52f1d7b40a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -215,7 +216,7 @@ case class Stack(children: Seq[Expression]) extends Generator { // Create the collection. val wrapperClass = classOf[mutable.WrappedArray[_]].getName ev.copy(code = - s""" + code""" |$code |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); """.stripMargin, isNull = FalseLiteral) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index ef790338bdd27..cec00b66f873c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -28,6 +28,7 @@ import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -293,7 +294,7 @@ abstract class HashExpression[E] extends Expression { foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |$hashResultType ${ev.value} = $seed; |$codes """.stripMargin) @@ -674,7 +675,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} ${ev.value} = $seed; |${CodeGenerator.JAVA_INT} $childHash = 0; |$codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 2a3cc580273ee..3b0141ad52cc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -42,8 +43,9 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getInputFilePath();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();", + isNull = FalseLiteral) } } @@ -65,8 +67,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getStartOffset();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral) } } @@ -88,7 +90,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getLength();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index fdd672c416a03..63943b1a4351a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, CharArrayWriter, InputStreamReader, StringWriter} +import java.io._ import scala.util.parsing.combinator.RegexParsers @@ -28,8 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.json.JsonInferSchema.inferField +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -527,22 +527,24 @@ case class JsonToStructs( override def nullable: Boolean = true // Used in `FunctionRegistry` - def this(child: Expression, schema: Expression) = + def this(child: Expression, schema: Expression, options: Map[String, String]) = this( - schema = JsonExprUtils.validateSchemaLiteral(schema), - options = Map.empty[String, String], + schema = JsonExprUtils.evalSchemaExpr(schema), + options = options, child = child, timeZoneId = None) + def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String]) + def this(child: Expression, schema: Expression, options: Expression) = this( - schema = JsonExprUtils.validateSchemaLiteral(schema), + schema = JsonExprUtils.evalSchemaExpr(schema), options = JsonExprUtils.convertToMapData(options), child = child, timeZoneId = None) override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { - case _: StructType | ArrayType(_: StructType, _) => + case _: StructType | ArrayType(_: StructType, _) | _: MapType => super.checkInputDataTypes() case _ => TypeCheckResult.TypeCheckFailure( s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.") @@ -552,6 +554,7 @@ case class JsonToStructs( lazy val rowSchema = nullableSchema match { case st: StructType => st case ArrayType(st: StructType, _) => st + case mt: MapType => mt } // This converts parsed rows to the desired output by the given schema. @@ -561,6 +564,8 @@ case class JsonToStructs( (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null case ArrayType(_: StructType, _) => (rows: Seq[InternalRow]) => new GenericArrayData(rows) + case _: MapType => + (rows: Seq[InternalRow]) => rows.head.getMap(0) } @transient @@ -607,6 +612,11 @@ case class JsonToStructs( } override def inputTypes: Seq[AbstractDataType] = StringType :: Nil + + override def sql: String = schema match { + case _: MapType => "entries" + case _ => super.sql + } } /** @@ -731,11 +741,44 @@ case class StructsToJson( override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil } +/** + * A function infers schema of JSON string. + */ +@ExpressionDescription( + usage = "_FUNC_(json[, options]) - Returns schema in the DDL format of JSON string.", + examples = """ + Examples: + > SELECT _FUNC_('[{"col":0}]'); + array> + """, + since = "2.4.0") +case class SchemaOfJson(child: Expression) + extends UnaryExpression with String2StringExpression with CodegenFallback { + + private val jsonOptions = new JSONOptions(Map.empty, "UTC") + private val jsonFactory = new JsonFactory() + jsonOptions.setJacksonOptions(jsonFactory) + + override def convert(v: UTF8String): UTF8String = { + val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, v)) { parser => + parser.nextToken() + inferField(parser, jsonOptions) + } + + UTF8String.fromString(dt.catalogString) + } +} + object JsonExprUtils { - def validateSchemaLiteral(exp: Expression): StructType = exp match { - case Literal(s, StringType) => CatalystSqlParser.parseTableSchema(s.toString) - case e => throw new AnalysisException(s"Expected a string literal instead of $e") + def evalSchemaExpr(exp: Expression): DataType = exp match { + case Literal(s, StringType) => DataType.fromDDL(s.toString) + case e @ SchemaOfJson(_: Literal) => + val ddlSchema = e.eval().asInstanceOf[UTF8String] + DataType.fromDDL(ddlSchema.toString) + case e => throw new AnalysisException( + "Schema should be specified in DDL format as a string literal" + + s" or output of the schema_of_json function instead of ${e.sql}") } def convertToMapData(exp: Expression): Map[String, String] = exp match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 246025b82d59e..0cc2a332f2c30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -57,6 +57,7 @@ object Literal { case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) + case c: Char => Literal(UTF8String.fromString(c.toString), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d)) case d: JavaBigDecimal => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala new file mode 100644 index 0000000000000..276a57266a6e0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -0,0 +1,569 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.commons.codec.digest.DigestUtils + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.MaskExpressionsUtils._ +import org.apache.spark.sql.catalyst.expressions.MaskLike._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + + +trait MaskLike { + def upper: String + def lower: String + def digit: String + + protected lazy val upperReplacement: Int = getReplacementChar(upper, defaultMaskedUppercase) + protected lazy val lowerReplacement: Int = getReplacementChar(lower, defaultMaskedLowercase) + protected lazy val digitReplacement: Int = getReplacementChar(digit, defaultMaskedDigit) + + protected val maskUtilsClassName: String = classOf[MaskExpressionsUtils].getName + + def inputStringLengthCode(inputString: String, length: String): String = { + s"${CodeGenerator.JAVA_INT} $length = $inputString.codePointCount(0, $inputString.length());" + } + + def appendMaskedToStringBuilderCode( + ctx: CodegenContext, + sb: String, + inputString: String, + offset: String, + numChars: String): String = { + val i = ctx.freshName("i") + val codePoint = ctx.freshName("codePoint") + s""" + |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { + | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); + | $sb.appendCodePoint($maskUtilsClassName.transformChar($codePoint, + | $upperReplacement, $lowerReplacement, + | $digitReplacement, $defaultMaskedOther)); + | $offset += Character.charCount($codePoint); + |} + """.stripMargin + } + + def appendUnchangedToStringBuilderCode( + ctx: CodegenContext, + sb: String, + inputString: String, + offset: String, + numChars: String): String = { + val i = ctx.freshName("i") + val codePoint = ctx.freshName("codePoint") + s""" + |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { + | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); + | $sb.appendCodePoint($codePoint); + | $offset += Character.charCount($codePoint); + |} + """.stripMargin + } + + def appendMaskedToStringBuilder( + sb: java.lang.StringBuilder, + inputString: String, + startOffset: Int, + numChars: Int): Int = { + var offset = startOffset + (1 to numChars) foreach { _ => + val codePoint = inputString.codePointAt(offset) + sb.appendCodePoint(transformChar( + codePoint, + upperReplacement, + lowerReplacement, + digitReplacement, + defaultMaskedOther)) + offset += Character.charCount(codePoint) + } + offset + } + + def appendUnchangedToStringBuilder( + sb: java.lang.StringBuilder, + inputString: String, + startOffset: Int, + numChars: Int): Int = { + var offset = startOffset + (1 to numChars) foreach { _ => + val codePoint = inputString.codePointAt(offset) + sb.appendCodePoint(codePoint) + offset += Character.charCount(codePoint) + } + offset + } +} + +trait MaskLikeWithN extends MaskLike { + def n: Int + protected lazy val charCount: Int = if (n < 0) 0 else n +} + +/** + * Utils for mask operations. + */ +object MaskLike { + val defaultCharCount = 4 + val defaultMaskedUppercase: Int = 'X' + val defaultMaskedLowercase: Int = 'x' + val defaultMaskedDigit: Int = 'n' + val defaultMaskedOther: Int = MaskExpressionsUtils.UNMASKED_VAL + + def extractCharCount(e: Expression): Int = e match { + case Literal(i, IntegerType | NullType) => + if (i == null) defaultCharCount else i.asInstanceOf[Int] + case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + + s"${IntegerType.simpleString}, but got literal of ${dt.simpleString}") + case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") + } + + def extractReplacement(e: Expression): String = e match { + case Literal(s, StringType | NullType) => if (s == null) null else s.toString + case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + + s"${StringType.simpleString}, but got literal of ${dt.simpleString}") + case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") + } +} + +/** + * Masks the input string. Additional parameters can be set to change the masking chars for + * uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, upper[, lower[, digit]]]) - Masks str. By default, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("abcd-EFGH-8765-4321", "U", "l", "#"); + llll-UUUU-####-#### + """) +// scalastyle:on line.size.limit +case class Mask(child: Expression, upper: String, lower: String, digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLike { + + def this(child: Expression) = this(child, null.asInstanceOf[String], null, null) + + def this(child: Expression, upper: Expression) = + this(child, extractReplacement(upper), null, null) + + def this(child: Expression, upper: Expression, lower: Expression) = + this(child, extractReplacement(upper), extractReplacement(lower), null) + + def this(child: Expression, upper: Expression, lower: Expression, digit: Expression) = + this(child, extractReplacement(upper), extractReplacement(lower), extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val sb = new java.lang.StringBuilder(length) + appendMaskedToStringBuilder(sb, str, 0, length) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |StringBuilder $sb = new StringBuilder($length); + |${CodeGenerator.JAVA_INT} $offset = 0; + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, length)} + |${ev.value} = UTF8String.fromString($sb.toString()); + """.stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) +} + +/** + * Masks the first N chars of the input string. N defaults to 4. Additional parameters can be set + * to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + nnnn-5678-8765-4321 + """) +// scalastyle:on line.size.limit +case class MaskFirstN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val endOfMask = if (charCount > length) length else charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) + appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val endOfMask = ctx.freshName("endOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $length ? $length : $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} + |${appendUnchangedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $endOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_first_n" +} + +/** + * Masks the last N chars of the input string. N defaults to 4. Additional parameters can be set + * to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + 1234-5678-8765-nnnn + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class MaskLastN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val startOfMask = if (charCount >= length) 0 else length - charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) + appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val startOfMask = ctx.freshName("startOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $length ? + | 0 : $length - $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} + |${appendMaskedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $startOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_last_n" +} + +/** + * Masks all but the first N chars of the input string. N defaults to 4. Additional parameters can + * be set to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + 1234-nnnn-nnnn-nnnn + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class MaskShowFirstN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val startOfMask = if (charCount > length) length else charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) + appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val startOfMask = ctx.freshName("startOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $length ? $length : $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} + |${appendMaskedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $startOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_show_first_n" +} + +/** + * Masks all but the last N chars of the input string. N defaults to 4. Additional parameters can + * be set to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + nnnn-nnnn-nnnn-4321 + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class MaskShowLastN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val endOfMask = if (charCount >= length) 0 else length - charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) + appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val endOfMask = ctx.freshName("endOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $length ? 0 : $length - $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} + |${appendUnchangedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $endOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_show_last_n" +} + +/** + * Returns a hashed value based on str. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str) - Returns a hashed value based on str. The hash is consistent and can be used to join masked values together across tables.", + examples = """ + Examples: + > SELECT _FUNC_("abcd-EFGH-8765-4321"); + 60c713f5ec6912229d2060df1c322776 + """) +// scalastyle:on line.size.limit +case class MaskHash(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def nullSafeEval(input: Any): Any = { + UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[UTF8String].toString)) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val digestUtilsClass = classOf[DigestUtils].getName.stripSuffix("$") + s""" + |${ev.value} = UTF8String.fromString($digestUtilsClass.md5Hex($input.toString())); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_hash" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index bc4cfcec47425..c2e1720259b53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.NumberConverter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1191,11 +1192,11 @@ abstract class RoundBase(child: Expression, scale: Expression, val javaType = CodeGenerator.javaType(dataType) if (scaleV == null) { // if scale is null, no need to eval its child at all - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 7eda65a867028..5d98dac46cf17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -88,7 +89,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null or false. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - ExprCode(code = s"""${eval.code} + ExprCode(code = code"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); |}""".stripMargin, isNull = TrueLiteral, @@ -117,12 +118,13 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.", + usage = """_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.""", examples = """ Examples: > SELECT _FUNC_(); 46707d92-02f4-4817-8116-a4c3b23e6266 - """) + """, + note = "The function is non-deterministic.") // scalastyle:on line.size.limit case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful { @@ -150,7 +152,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta ctx.addPartitionInitializationStatement(s"$randomGen = " + "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + s"${randomSeed.get}L + partitionIndex);") - ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", + ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 0787342bce6bc..2eeed3bbb2d91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -111,7 +112,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { @@ -232,7 +233,7 @@ case class IsNaN(child: Expression) extends UnaryExpression val eval = child.genCode(ctx) child.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) @@ -278,7 +279,7 @@ case class NaNvl(left: Expression, right: Expression) val rightGen = right.genCode(ctx) left.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${leftGen.code} boolean ${ev.isNull} = false; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -440,7 +441,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} $nonnull = 0; |do { | $codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index bc17d1229420a..2bf4203d0fec3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -33,8 +33,10 @@ import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.Utils /** @@ -268,7 +270,7 @@ case class StaticInvoke( s"${ev.value} = $callFunc;" } - val code = s""" + val code = code""" $argCode $prepareIsNull $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -384,8 +386,7 @@ case class Invoke( """ } - val code = s""" - ${obj.code} + val code = obj.code + code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${obj.isNull}) { @@ -491,7 +492,7 @@ case class NewInstance( s"new $className($argString)" } - val code = s""" + val code = code""" $argCode ${outer.map(_.code).getOrElse("")} final $javaType ${ev.value} = ${ev.isNull} ? @@ -531,9 +532,7 @@ case class UnwrapOption( val javaType = CodeGenerator.javaType(dataType) val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); @@ -563,9 +562,7 @@ case class WrapOption(child: Expression, optType: DataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); @@ -934,8 +931,7 @@ case class MapObjects private( ) } - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1040,11 +1036,13 @@ case class CatalystToExternalMap private( private lazy val valueConverter = CatalystTypeConverters.createToScalaConverter(inputMapType.valueType) - private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + private lazy val (newMapBuilderMethod, moduleField) = { val clazz = Utils.classForName(collClass.getCanonicalName + "$") - val module = clazz.getField("MODULE$").get(null) - val method = clazz.getMethod("newBuilder") - method.invoke(module).asInstanceOf[Builder[AnyRef, AnyRef]] + (clazz.getMethod("newBuilder"), clazz.getField("MODULE$").get(null)) + } + + private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + newMapBuilderMethod.invoke(moduleField).asInstanceOf[Builder[AnyRef, AnyRef]] } override def eval(input: InternalRow): Any = { @@ -1144,8 +1142,7 @@ case class CatalystToExternalMap private( """ val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();" - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1252,8 +1249,72 @@ case class ExternalMapToCatalyst private( override def dataType: MapType = MapType( keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = { + val rowBuffer = InternalRow.fromSeq(Array[Any](1)) + def rowWrapper(data: Any): InternalRow = { + rowBuffer.update(0, data) + rowBuffer + } + + child.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[java.util.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + val iter = data.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + val (key, value) = (entry.getKey, entry.getValue) + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 + } + (keys, values) + } + + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[scala.collection.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + var i = 0 + for ((key, value) <- data) { + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 + } + (keys, values) + } + } + } + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (result != null) { + val (keys, values) = mapCatalystConverter(result) + new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) + } else { + null + } + } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputMap = child.genCode(ctx) @@ -1324,9 +1385,8 @@ case class ExternalMapToCatalyst private( val mapCls = classOf[ArrayBasedMapData].getName val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) - val code = - s""" - ${inputMap.code} + val code = inputMap.code + + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${inputMap.isNull}) { final int $length = ${inputMap.value}.size(); @@ -1404,7 +1464,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) val schemaField = ctx.addReferenceObj("schema", schema) val code = - s""" + code""" |Object[] $values = new Object[${children.size}]; |$childrenCode |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); @@ -1432,8 +1492,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) val javaType = CodeGenerator.javaType(dataType) val serialize = s"$serializer.serialize(${input.value}, null).array()" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize; """ @@ -1465,8 +1524,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B val deserialize = s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; """ @@ -1547,9 +1605,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp funcName = "initializeJavaBean", extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil) - val code = - s""" - |${instanceGen.code} + val code = instanceGen.code + + code""" |$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value}; |if (!${instanceGen.isNull}) { | $initializeCode @@ -1597,9 +1654,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) // because errMsgField is used only when the value is null. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - val code = s""" - ${childGen.code} - + val code = childGen.code + code""" if (${childGen.isNull}) { throw new NullPointerException($errMsgField); } @@ -1642,7 +1697,7 @@ case class GetExternalRowField( // because errMsgField is used only when the field is null. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) val row = child.genCode(ctx) - val code = s""" + val code = code""" ${row.code} if (${row.isNull}) { @@ -1670,13 +1725,36 @@ case class ValidateExternalType(child: Expression, expected: DataType) override def nullable: Boolean = child.nullable - override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected) - - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected) private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" + private lazy val checkType: (Any) => Boolean = expected match { + case _: DecimalType => + (value: Any) => { + value.isInstanceOf[java.math.BigDecimal] || value.isInstanceOf[scala.math.BigDecimal] || + value.isInstanceOf[Decimal] + } + case _: ArrayType => + (value: Any) => { + value.getClass.isArray || value.isInstanceOf[Seq[_]] + } + case _ => + val dataTypeClazz = ScalaReflection.javaBoxedType(dataType) + (value: Any) => { + dataTypeClazz.isInstance(value) + } + } + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (checkType(result)) { + result + } else { + throw new RuntimeException(s"${result.getClass.getName}$errMsg") + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the type doesn't match. @@ -1689,12 +1767,12 @@ case class ValidateExternalType(child: Expression, expected: DataType) Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") case _: ArrayType => - s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" + s"$obj.getClass().isArray() || $obj instanceof ${classOf[Seq[_]].getName}" case _ => s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } - val code = s""" + val code = code""" ${input.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${input.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1a48995358af7..8a06daa37132d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.catalyst +import java.util.Locale + import com.google.common.collect.Maps +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -138,6 +142,88 @@ package object expressions { def indexOf(exprId: ExprId): Int = { Option(exprIdToOrdinal.get(exprId)).getOrElse(-1) } + + private def unique[T](m: Map[T, Seq[Attribute]]): Map[T, Seq[Attribute]] = { + m.mapValues(_.distinct).map(identity) + } + + /** Map to use for direct case insensitive attribute lookups. */ + @transient private lazy val direct: Map[String, Seq[Attribute]] = { + unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT))) + } + + /** Map to use for qualified case insensitive attribute lookups. */ + @transient private val qualified: Map[(String, String), Seq[Attribute]] = { + val grouped = attrs.filter(_.qualifier.isDefined).groupBy { a => + (a.qualifier.get.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT)) + } + unique(grouped) + } + + /** Perform attribute resolution given a name and a resolver. */ + def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = { + // Collect matching attributes given a name and a lookup. + def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = { + candidates.toSeq.flatMap(_.collect { + case a if resolver(a.name, name) => a.withName(name) + }) + } + + // Find matches for the given name assuming that the 1st part is a qualifier (i.e. table name, + // alias, or subquery alias) and the 2nd part is the actual name. This returns a tuple of + // matched attributes and a list of parts that are to be resolved. + // + // For example, consider an example where "a" is the table name, "b" is the column name, + // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", + // and the second element will be List("c"). + val matches = nameParts match { + case qualifier +: name +: nestedFields => + val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) + val attributes = collectMatches(name, qualified.get(key)).filter { a => + resolver(qualifier, a.qualifier.get) + } + (attributes, nestedFields) + case all => + (Nil, all) + } + + // If none of attributes match `table.column` pattern, we try to resolve it as a column. + val (candidates, nestedFields) = matches match { + case (Seq(), _) => + val name = nameParts.head + val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT))) + (attributes, nameParts.tail) + case _ => matches + } + + def name = UnresolvedAttribute(nameParts).name + candidates match { + case Seq(a) if nestedFields.nonEmpty => + // One match, but we also need to extract the requested nested field. + // The foldLeft adds ExtractValues for every remaining parts of the identifier, + // and aliased it with the last part of the name. + // For example, consider "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final + // expression as "c". + val fieldExprs = nestedFields.foldLeft(a: Expression) { (e, name) => + ExtractValue(e, Literal(name), resolver) + } + Some(Alias(fieldExprs, nestedFields.last)()) + + case Seq(a) => + // One match, no nested fields, use it. + Some(a) + + case Seq() => + // No matches. + None + + case ambiguousReferences => + // More than one match. + val referenceNames = ambiguousReferences.map(_.qualifiedName).mkString(", ") + throw new AnalysisException(s"Reference '$name' is ambiguous, could be: $referenceNames.") + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e195ec17f3bcf..f54103c4fbfba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -22,6 +22,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -36,6 +37,14 @@ object InterpretedPredicate { case class InterpretedPredicate(expression: Expression) extends BasePredicate { override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean] + + override def initialize(partitionIndex: Int): Unit = { + super.initialize(partitionIndex) + expression.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + } + } } /** @@ -282,7 +291,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { }.mkString("\n")) ev.copy(code = - s""" + code""" |${valueGen.code} |byte $tmpResult = $HAS_NULL; |if (!${valueGen.isNull}) { @@ -346,7 +355,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with "" } ev.copy(code = - s""" + code""" |${childGen.code} |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; @@ -398,7 +407,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with // The result should be `false`, if any of them is `false` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = false; @@ -407,7 +416,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = false; @@ -462,7 +471,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { ev.isNull = FalseLiteral - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = true; @@ -471,7 +480,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = true; @@ -613,7 +622,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) - ev.copy(code = eval1.code + eval2.code + s""" + ev.copy(code = eval1.code + eval2.code + code""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 70186053617f8..926c2f00d430d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -68,7 +69,8 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful 0.8446490682263027 > SELECT _FUNC_(null); 0.8446490682263027 - """) + """, + note = "The function is non-deterministic in general case.") // scalastyle:on line.size.limit case class Rand(child: Expression) extends RDG { @@ -81,7 +83,7 @@ case class Rand(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = FalseLiteral) } @@ -96,7 +98,7 @@ object Rand { /** Generate a random column with i.i.d. values drawn from the standard normal distribution. */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.", + usage = """_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.""", examples = """ Examples: > SELECT _FUNC_(); @@ -105,7 +107,8 @@ object Rand { 1.1164209726833079 > SELECT _FUNC_(null); 1.1164209726833079 - """) + """, + note = "The function is non-deterministic in general case.") // scalastyle:on line.size.limit case class Randn(child: Expression) extends RDG { @@ -118,7 +121,7 @@ case class Randn(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index ad0c0791d895f..7b68bb771faf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -23,6 +23,7 @@ import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -123,7 +124,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -132,7 +133,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) @@ -198,7 +199,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -207,7 +208,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index ea005a26a4c8b..bedad7da334ae 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -27,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -105,7 +106,7 @@ case class ConcatWs(children: Seq[Expression]) expressions = inputs, funcName = "valueConcatWs", extraArguments = ("UTF8String[]", args) :: Nil) - ev.copy(s""" + ev.copy(code""" UTF8String[] $args = new UTF8String[$numArgs]; ${separator.code} $codes @@ -149,7 +150,7 @@ case class ConcatWs(children: Seq[Expression]) } }.unzip - val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code)) + val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code.toString)) val varargCounts = ctx.splitExpressionsWithCurrentInputs( expressions = varargCount, @@ -176,7 +177,7 @@ case class ConcatWs(children: Seq[Expression]) foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n")) ev.copy( - s""" + code""" $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; int $idxVararg = 0; @@ -288,7 +289,7 @@ case class Elt(children: Seq[Expression]) extends Expression { }.mkString) ev.copy( - s""" + code""" |${index.code} |final int $indexVal = ${index.value}; |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; @@ -654,7 +655,7 @@ case class StringTrim( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -671,7 +672,7 @@ case class StringTrim( } else { ${ev.value} = ${srcString.value}.trim(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -754,7 +755,7 @@ case class StringTrimLeft( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -771,7 +772,7 @@ case class StringTrimLeft( } else { ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -856,7 +857,7 @@ case class StringTrimRight( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -873,7 +874,7 @@ case class StringTrimRight( } else { ${ev.value} = ${srcString.value}.trimRight(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -1024,7 +1025,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) val substrGen = substr.genCode(ctx) val strGen = str.genCode(ctx) val startGen = start.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" int ${ev.value} = 0; boolean ${ev.isNull} = false; ${startGen.code} @@ -1350,7 +1351,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val formatter = classOf[java.util.Formatter].getName val sb = ctx.freshName("sb") val stringBuffer = classOf[StringBuffer].getName - ev.copy(code = s""" + ev.copy(code = code""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1915,12 +1916,15 @@ case class Encode(value: Expression, charset: Expression) usage = """ _FUNC_(expr1, expr2) - Formats the number `expr1` like '#,###,###.##', rounded to `expr2` decimal places. If `expr2` is 0, the result has no decimal point or fractional part. + `expr2` also accept a user specified format. This is supposed to function like MySQL's FORMAT. """, examples = """ Examples: > SELECT _FUNC_(12332.123456, 4); 12,332.1235 + > SELECT _FUNC_(12332.123456, '##################.###'); + 12332.123 """) case class FormatNumber(x: Expression, d: Expression) extends BinaryExpression with ExpectsInputTypes { @@ -1929,14 +1933,20 @@ case class FormatNumber(x: Expression, d: Expression) override def right: Expression = d override def dataType: DataType = StringType override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + override def inputTypes: Seq[AbstractDataType] = + Seq(NumericType, TypeCollection(IntegerType, StringType)) + + private val defaultFormat = "#,###,###,###,###,###,##0" // Associated with the pattern, for the last d value, and we will update the // pattern (DecimalFormat) once the new coming d value differ with the last one. // This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after // serialization (numberFormat has not been updated for dValue = 0). @transient - private var lastDValue: Option[Int] = None + private var lastDIntValue: Option[Int] = None + + @transient + private var lastDStringValue: Option[String] = None // A cached DecimalFormat, for performance concern, we will change it // only if the d value changed. @@ -1949,33 +1959,49 @@ case class FormatNumber(x: Expression, d: Expression) private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { - val dValue = dObject.asInstanceOf[Int] - if (dValue < 0) { - return null - } - - lastDValue match { - case Some(last) if last == dValue => - // use the current pattern - case _ => - // construct a new DecimalFormat only if a new dValue - pattern.delete(0, pattern.length) - pattern.append("#,###,###,###,###,###,##0") - - // decimal place - if (dValue > 0) { - pattern.append(".") - - var i = 0 - while (i < dValue) { - i += 1 - pattern.append("0") - } + right.dataType match { + case IntegerType => + val dValue = dObject.asInstanceOf[Int] + if (dValue < 0) { + return null } - lastDValue = Some(dValue) + lastDIntValue match { + case Some(last) if last == dValue => + // use the current pattern + case _ => + // construct a new DecimalFormat only if a new dValue + pattern.delete(0, pattern.length) + pattern.append(defaultFormat) + + // decimal place + if (dValue > 0) { + pattern.append(".") + + var i = 0 + while (i < dValue) { + i += 1 + pattern.append("0") + } + } + + lastDIntValue = Some(dValue) - numberFormat.applyLocalizedPattern(pattern.toString) + numberFormat.applyLocalizedPattern(pattern.toString) + } + case StringType => + val dValue = dObject.asInstanceOf[UTF8String].toString + lastDStringValue match { + case Some(last) if last == dValue => + case _ => + pattern.delete(0, pattern.length) + lastDStringValue = Some(dValue) + if (dValue.isEmpty) { + numberFormat.applyLocalizedPattern(defaultFormat) + } else { + numberFormat.applyLocalizedPattern(dValue) + } + } } x.dataType match { @@ -2007,35 +2033,52 @@ case class FormatNumber(x: Expression, d: Expression) // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') // as a decimal separator. val usLocale = "US" - val i = ctx.freshName("i") - val dFormat = ctx.freshName("dFormat") - val lastDValue = - ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") - val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") val numberFormat = ctx.addMutableState(df, "numberFormat", v => s"""$v = new $df("", new $dfs($l.$usLocale));""") - s""" - if ($d >= 0) { - $pattern.delete(0, $pattern.length()); - if ($d != $lastDValue) { - $pattern.append("#,###,###,###,###,###,##0"); - - if ($d > 0) { - $pattern.append("."); - for (int $i = 0; $i < $d; $i++) { - $pattern.append("0"); + right.dataType match { + case IntegerType => + val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") + val i = ctx.freshName("i") + val lastDValue = + ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") + s""" + if ($d >= 0) { + $pattern.delete(0, $pattern.length()); + if ($d != $lastDValue) { + $pattern.append("$defaultFormat"); + + if ($d > 0) { + $pattern.append("."); + for (int $i = 0; $i < $d; $i++) { + $pattern.append("0"); + } + } + $lastDValue = $d; + $numberFormat.applyLocalizedPattern($pattern.toString()); } + ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + } else { + ${ev.value} = null; + ${ev.isNull} = true; } - $lastDValue = $d; - $numberFormat.applyLocalizedPattern($pattern.toString()); - } - ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); - } else { - ${ev.value} = null; - ${ev.isNull} = true; - } - """ + """ + case StringType => + val lastDValue = ctx.addMutableState("String", "lastDValue", v => s"""$v = null;""") + val dValue = ctx.freshName("dValue") + s""" + String $dValue = $d.toString(); + if (!$dValue.equals($lastDValue)) { + $lastDValue = $dValue; + if ($dValue.isEmpty()) { + $numberFormat.applyLocalizedPattern("$defaultFormat"); + } else { + $numberFormat.applyLocalizedPattern($dValue); + } + } + ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + """ + } }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 78895f1c2f6f5..f957aaa96e98c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -21,7 +21,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, NoOp} import org.apache.spark.sql.types._ /** @@ -297,6 +297,37 @@ trait WindowFunction extends Expression { def frame: WindowFrame = UnspecifiedFrame } +/** + * Case objects that describe whether a window function is a SQL window function or a Python + * user-defined window function. + */ +sealed trait WindowFunctionType + +object WindowFunctionType { + case object SQL extends WindowFunctionType + case object Python extends WindowFunctionType + + def functionType(windowExpression: NamedExpression): WindowFunctionType = { + val t = windowExpression.collectFirst { + case _: WindowFunction | _: AggregateFunction => SQL + case udf: PythonUDF if PythonUDF.isWindowPandasUDF(udf) => Python + } + + // Normally a window expression would either have a SQL window function, a SQL + // aggregate function or a python window UDF. However, sometimes the optimizer will replace + // the window function if the value of the window function can be predetermined. + // For example, for query: + // + // select count(NULL) over () from values 1.0, 2.0, 3.0 T(a) + // + // The window function will be replaced by expression literal(0) + // To handle this case, if a window expression doesn't have a regular window function, we + // consider its type to be SQL as literal(0) is also a SQL expression. + t.getOrElse(SQL) + } +} + + /** * An offset window function is a window function that returns the value of the input column offset * by a number of rows within the partition. For instance: an OffsetWindowfunction for value x with @@ -342,7 +373,10 @@ abstract class OffsetWindowFunction override lazy val frame: WindowFrame = { val boundary = direction match { case Ascending => offset - case Descending => UnaryMinus(offset) + case Descending => UnaryMinus(offset) match { + case e: Expression if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) + case o => o + } } SpecifiedWindowFrame(RowFrame, boundary, boundary) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index 025a388aacaa5..3e8e6db1dbd22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -18,10 +18,14 @@ package org.apache.spark.sql.catalyst.json import java.io.{ByteArrayInputStream, InputStream, InputStreamReader} +import java.nio.channels.Channels +import java.nio.charset.Charset import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.hadoop.io.Text +import sun.nio.cs.StreamDecoder +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.unsafe.types.UTF8String private[sql] object CreateJacksonParser extends Serializable { @@ -43,7 +47,48 @@ private[sql] object CreateJacksonParser extends Serializable { jsonFactory.createParser(record.getBytes, 0, record.getLength) } - def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = { - jsonFactory.createParser(record) + // Jackson parsers can be ranked according to their performance: + // 1. Array based with actual encoding UTF-8 in the array. This is the fastest parser + // but it doesn't allow to set encoding explicitly. Actual encoding is detected automatically + // by checking leading bytes of the array. + // 2. InputStream based with actual encoding UTF-8 in the stream. Encoding is detected + // automatically by analyzing first bytes of the input stream. + // 3. Reader based parser. This is the slowest parser used here but it allows to create + // a reader with specific encoding. + // The method creates a reader for an array with given encoding and sets size of internal + // decoding buffer according to size of input array. + private def getStreamDecoder(enc: String, in: Array[Byte], length: Int): StreamDecoder = { + val bais = new ByteArrayInputStream(in, 0, length) + val byteChannel = Channels.newChannel(bais) + val decodingBufferSize = Math.min(length, 8192) + val decoder = Charset.forName(enc).newDecoder() + + StreamDecoder.forDecoder(byteChannel, decoder, decodingBufferSize) + } + + def text(enc: String, jsonFactory: JsonFactory, record: Text): JsonParser = { + val sd = getStreamDecoder(enc, record.getBytes, record.getLength) + jsonFactory.createParser(sd) + } + + def inputStream(jsonFactory: JsonFactory, is: InputStream): JsonParser = { + jsonFactory.createParser(is) + } + + def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = { + jsonFactory.createParser(new InputStreamReader(is, enc)) + } + + def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = { + val ba = row.getBinary(0) + + jsonFactory.createParser(ba, 0, ba.length) + } + + def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = { + val binary = row.getBinary(0) + val sd = getStreamDecoder(enc, binary, binary.length) + + jsonFactory.createParser(sd) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5c9adc3332bc0..47eeb70e00427 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.json -import java.nio.charset.StandardCharsets +import java.nio.charset.{Charset, StandardCharsets} import java.util.{Locale, TimeZone} import com.fasterxml.jackson.core.{JsonFactory, JsonParser} @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util._ * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. */ private[sql] class JSONOptions( - @transient private val parameters: CaseInsensitiveMap[String], + @transient val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { @@ -73,6 +73,9 @@ private[sql] class JSONOptions( val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + // Whether to ignore column of all null values or empty array/struct during schema inference + val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) + val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) @@ -86,14 +89,28 @@ private[sql] class JSONOptions( val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) + /** + * A string between two consecutive JSON records. + */ val lineSeparator: Option[String] = parameters.get("lineSep").map { sep => require(sep.nonEmpty, "'lineSep' cannot be an empty string.") sep } - // Note that the option 'lineSep' uses a different default value in read and write. - val lineSeparatorInRead: Option[Array[Byte]] = - lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) - // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + + protected def checkedEncoding(enc: String): String = enc + + /** + * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE. + * If the encoding is not specified (None) in read, it will be detected automatically + * when the multiLine option is set to `true`. If encoding is not specified in write, + * UTF-8 is used by default. + */ + val encoding: Option[String] = parameters.get("encoding") + .orElse(parameters.get("charset")).map(checkedEncoding) + + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(encoding.getOrElse("UTF-8")) + } val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n") /** Sets config options on a Jackson [[JsonFactory]]. */ @@ -108,3 +125,46 @@ private[sql] class JSONOptions( factory.configure(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS, allowUnquotedControlChars) } } + +private[sql] class JSONOptionsInRead( + @transient override val parameters: CaseInsensitiveMap[String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String) + extends JSONOptions(parameters, defaultTimeZoneId, defaultColumnNameOfCorruptRecord) { + + def this( + parameters: Map[String, String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String = "") = { + this( + CaseInsensitiveMap(parameters), + defaultTimeZoneId, + defaultColumnNameOfCorruptRecord) + } + + protected override def checkedEncoding(enc: String): String = { + val isBlacklisted = JSONOptionsInRead.blacklist.contains(Charset.forName(enc)) + require(multiLine || !isBlacklisted, + s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled: + |Blacklist: ${JSONOptionsInRead.blacklist.mkString(", ")}""".stripMargin) + + val isLineSepRequired = + multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty + require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") + + enc + } +} + +private[sql] object JSONOptionsInRead { + // The following encodings are not supported in per-line mode (multiline is false) + // because they cause some problems in reading files with BOM which is supposed to + // present in the files with such encodings. After splitting input files by lines, + // only the first lines will have the BOM which leads to impossibility for reading + // the rest lines. Besides of that, the lineSep option must have the BOM in such + // encodings which can never present between lines. + val blacklist = Seq( + Charset.forName("UTF-16"), + Charset.forName("UTF-32") + ) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 7f6956994f31f..c3a4ca8f64bf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.json -import java.io.ByteArrayOutputStream +import java.io.{ByteArrayOutputStream, CharConversionException} import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -36,7 +36,7 @@ import org.apache.spark.util.Utils * Constructs a parser for a given schema that translates a json string to an [[InternalRow]]. */ class JacksonParser( - schema: StructType, + schema: DataType, val options: JSONOptions) extends Logging { import JacksonUtils._ @@ -57,7 +57,14 @@ class JacksonParser( * to a value according to a desired schema. This is a wrapper for the method * `makeConverter()` to handle a row wrapped with an array. */ - private def makeRootConverter(st: StructType): JsonParser => Seq[InternalRow] = { + private def makeRootConverter(dt: DataType): JsonParser => Seq[InternalRow] = { + dt match { + case st: StructType => makeStructRootConverter(st) + case mt: MapType => makeMapRootConverter(mt) + } + } + + private def makeStructRootConverter(st: StructType): JsonParser => Seq[InternalRow] = { val elementConverter = makeConverter(st) val fieldConverters = st.map(_.dataType).map(makeConverter).toArray (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, st) { @@ -87,6 +94,13 @@ class JacksonParser( } } + private def makeMapRootConverter(mt: MapType): JsonParser => Seq[InternalRow] = { + val fieldConverter = makeConverter(mt.valueType) + (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, mt) { + case START_OBJECT => Seq(InternalRow(convertMap(parser, fieldConverter))) + } + } + /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. @@ -361,6 +375,14 @@ class JacksonParser( // For such records, all fields other than the field configured by // `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => recordLiteral(record), () => None, e) + case e: CharConversionException if options.encoding.isEmpty => + val msg = + """JSON parser cannot handle a character in its input. + |Specifying encoding as an input option explicitly might help to resolve the issue. + |""".stripMargin + e.getMessage + val wrappedCharException = new CharConversionException(msg) + wrappedCharException.initCause(e) + throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala similarity index 87% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index a270a6451d5dd..491ca005877f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.json +package org.apache.spark.sql.catalyst.json import java.util.Comparator @@ -25,7 +25,6 @@ import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil -import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -45,8 +44,9 @@ private[sql] object JsonInferSchema { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord - // perform schema inference on each row and merge afterwards - val rootType = json.mapPartitions { iter => + // In each RDD partition, perform schema inference on each row and merge afterwards. + val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode) + val mergedTypesFromPartitions = json.mapPartitions { iter => val factory = new JsonFactory() configOptions.setJacksonOptions(factory) iter.flatMap { row => @@ -66,11 +66,15 @@ private[sql] object JsonInferSchema { s"Parse Mode: ${FailFastMode.name}.", e) } } - } - }.fold(StructType(Nil))( - compatibleRootType(columnNameOfCorruptRecord, parseMode)) + }.reduceOption(typeMerger).toIterator + } + + // Here we get RDD local iterator then fold, instead of calling `RDD.fold` directly, because + // `RDD.fold` will run the fold function in DAGScheduler event loop thread, which may not have + // active SparkSession and `SQLConf.get` may point to the wrong configs. + val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger) - canonicalizeType(rootType) match { + canonicalizeType(rootType, configOptions) match { case Some(st: StructType) => st case _ => // canonicalizeType erases all empty structs, including the only one we want to keep @@ -98,7 +102,7 @@ private[sql] object JsonInferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { + def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType @@ -176,33 +180,33 @@ private[sql] object JsonInferSchema { } /** - * Convert NullType to StringType and remove StructTypes with no fields + * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields, + * drops NullTypes or converts them to StringType based on provided options. */ - private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match { - case at @ ArrayType(elementType, _) => - for { - canonicalType <- canonicalizeType(elementType) - } yield { - at.copy(canonicalType) - } + private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match { + case at: ArrayType => + canonicalizeType(at.elementType, options) + .map(t => at.copy(elementType = t)) case StructType(fields) => - val canonicalFields: Array[StructField] = for { - field <- fields - if field.name.length > 0 - canonicalType <- canonicalizeType(field.dataType) - } yield { - field.copy(dataType = canonicalType) + val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f => + canonicalizeType(f.dataType, options) + .map(t => f.copy(dataType = t)) } - - if (canonicalFields.length > 0) { - Some(StructType(canonicalFields)) + // SPARK-8093: empty structs should be deleted + if (canonicalFields.isEmpty) { + None } else { - // per SPARK-8093: empty structs should be deleted + Some(StructType(canonicalFields)) + } + + case NullType => + if (options.dropFieldIfAllNull) { None + } else { + Some(StringType) } - case NullType => Some(StringType) case other => Some(other) } @@ -329,8 +333,8 @@ private[sql] object JsonInferSchema { ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in - // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when - // the given `DecimalType` is not capable of the given `IntegralType`. + // `findTightestCommonType`. Both cases below will be executed only when the given + // `DecimalType` is not capable of the given `IntegralType`. case (t1: IntegralType, t2: DecimalType) => compatibleType(DecimalType.forType(t1), t2) case (t1: DecimalType, t2: IntegralType) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 913354e4df0e6..2cc27d82f7d20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -526,9 +526,10 @@ object ColumnPruning extends Rule[LogicalPlan] { /** * The Project before Filter is not necessary but conflict with PushPredicatesThroughProject, - * so remove it. + * so remove it. Since the Projects have been added top-down, we need to remove in bottom-up + * order, otherwise lower Projects can be missed. */ - private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform { + private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp { case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child))) if p2.outputSet.subsetOf(child.outputSet) => p1.copy(child = f.copy(child = child)) @@ -621,12 +622,15 @@ object CollapseRepartition extends Rule[LogicalPlan] { /** * Collapse Adjacent Window Expression. * - If the partition specs and order specs are the same and the window expression are - * independent, collapse into the parent. + * independent and are of the same window function type, collapse into the parent. */ object CollapseWindow extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild)) - if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty => + if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty && + // This assumes Window contains the same type of window expressions. This is ensured + // by ExtractWindowFunctions. + WindowFunctionType.functionType(we1.head) == WindowFunctionType.functionType(we2.head) => w1.copy(windowExpressions = we2 ++ we1, child = grandChild) } } @@ -637,13 +641,11 @@ object CollapseWindow extends Rule[LogicalPlan] { * constraints. These filters are currently inserted to the existing conditions in the Filter * operators and on either side of Join operators. * - * In addition, for left/right outer joins, infer predicate from the preserved side of the Join - * operator and push the inferred filter over to the null-supplying side. For example, if the - * preserved side has constraints of the form 'a > 5' and the join condition is 'a = b', in - * which 'b' is an attribute from the null-supplying side, a [[Filter]] operator of 'b > 5' will - * be applied to the null-supplying side. + * Note: While this optimization is applicable to a lot of types of join, it primarily benefits + * Inner and LeftSemi joins. */ -object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { +object InferFiltersFromConstraints extends Rule[LogicalPlan] + with PredicateHelper with ConstraintHelper { def apply(plan: LogicalPlan): LogicalPlan = { if (SQLConf.get.constraintPropagationEnabled) { @@ -664,53 +666,52 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe } case join @ Join(left, right, joinType, conditionOpt) => - // Only consider constraints that can be pushed down completely to either the left or the - // right child - val constraints = join.allConstraints.filter { c => - c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet) - } - // Remove those constraints that are already enforced by either the left or the right child - val additionalConstraints = constraints -- (left.constraints ++ right.constraints) - val newConditionOpt = conditionOpt match { - case Some(condition) => - val newFilters = additionalConstraints -- splitConjunctivePredicates(condition) - if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else conditionOpt - case None => - additionalConstraints.reduceOption(And) - } - // Infer filter for left/right outer joins - val newLeftOpt = joinType match { - case RightOuter if newConditionOpt.isDefined => - val inferredConstraints = left.getRelevantConstraints( - left.constraints - .union(right.constraints) - .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) - val newFilters = inferredConstraints - .filterNot(left.constraints.contains) - .reduceLeftOption(And) - newFilters.map(Filter(_, left)) - case _ => None - } - val newRightOpt = joinType match { - case LeftOuter if newConditionOpt.isDefined => - val inferredConstraints = right.getRelevantConstraints( - right.constraints - .union(left.constraints) - .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) - val newFilters = inferredConstraints - .filterNot(right.constraints.contains) - .reduceLeftOption(And) - newFilters.map(Filter(_, right)) - case _ => None - } + joinType match { + // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an + // inner join, it just drops the right side in the final output. + case _: InnerLike | LeftSemi => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newLeft = inferNewFilter(left, allConstraints) + val newRight = inferNewFilter(right, allConstraints) + join.copy(left = newLeft, right = newRight) - if ((newConditionOpt.isDefined && (newConditionOpt ne conditionOpt)) - || newLeftOpt.isDefined || newRightOpt.isDefined) { - Join(newLeftOpt.getOrElse(left), newRightOpt.getOrElse(right), joinType, newConditionOpt) - } else { - join + // For right outer join, we can only infer additional filters for left side. + case RightOuter => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newLeft = inferNewFilter(left, allConstraints) + join.copy(left = newLeft) + + // For left join, we can only infer additional filters for right side. + case LeftOuter | LeftAnti => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newRight = inferNewFilter(right, allConstraints) + join.copy(right = newRight) + + case _ => join } } + + private def getAllConstraints( + left: LogicalPlan, + right: LogicalPlan, + conditionOpt: Option[Expression]): Set[Expression] = { + val baseConstraints = left.constraints.union(right.constraints) + .union(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet) + baseConstraints.union(inferAdditionalConstraints(baseConstraints)) + } + + private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = { + val newPredicates = constraints + .union(constructIsNotNullConstraints(constraints, plan.output)) + .filter { c => + c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic + } -- plan.constraints + if (newPredicates.isEmpty) { + plan + } else { + Filter(newPredicates.reduce(And), plan) + } + } } /** @@ -770,12 +771,29 @@ object EliminateSorts extends Rule[LogicalPlan] { } /** - * Removes Sort operation if the child is already sorted + * Removes redundant Sort operation. This can happen: + * 1) if the child is already sorted + * 2) if there is another Sort operator separated by 0...n Project/Filter operators */ object RemoveRedundantSorts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => child + case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child)) + } + + def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match { + case Sort(_, _, child) => recursiveRemoveSort(child) + case other if canEliminateSort(other) => + other.withNewChildren(other.children.map(recursiveRemoveSort)) + case _ => plan + } + + def canEliminateSort(plan: LogicalPlan): Boolean = plan match { + case p: Project => p.projectList.forall(_.deterministic) + case f: Filter => f.condition.deterministic + case _: ResolvedHint => true + case _ => false } } @@ -1168,12 +1186,14 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _) if isCartesianProduct(j) => throw new AnalysisException( - s"""Detected cartesian product for ${j.joinType.sql} join between logical plans + s"""Detected implicit cartesian product for ${j.joinType.sql} join between logical plans |${left.treeString(false).trim} |and |${right.treeString(false).trim} |Join condition is missing or trivial. - |Use the CROSS JOIN syntax to allow cartesian products between these relations.""" + |Either: use the CROSS JOIN syntax to allow cartesian products between these + |relations, or: enable implicit cartesian products by setting the configuration + |variable spark.sql.crossJoin.enabled=true""" .stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1c0b7bd806801..1d363b8146e3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -21,7 +21,6 @@ import scala.collection.immutable.HashSet import scala.collection.mutable.{ArrayBuffer, Stack} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 709db6d8bec7d..de89e17e51f1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -116,15 +116,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // (a1,a2,...) = (b1,b2,...) // to // (a1=b1 OR isnull(a1=b1)) AND (a2=b2 OR isnull(a2=b2)) AND ... - val joinConds = splitConjunctivePredicates(joinCond.get) + val baseJoinConds = splitConjunctivePredicates(joinCond.get) + val nullAwareJoinConds = baseJoinConds.map(c => Or(c, IsNull(c))) // After that, add back the correlated join predicate(s) in the subquery // Example: // SELECT ... FROM A WHERE A.A1 NOT IN (SELECT B.B1 FROM B WHERE B.B2 = A.A2 AND B.B3 > 1) // will have the final conditions in the LEFT ANTI as - // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) - val pairs = (joinConds.map(c => Or(c, IsNull(c))) ++ conditions).reduceLeft(And) + // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1 + val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And) // Deduplicate conflicting attributes if any. - dedupJoin(Join(outerPlan, sub, LeftAnti, Option(pairs))) + dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond))) case (p, predicate) => val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) Project(p.output, Filter(newCond.get, inputPlan)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bdc357d54a878..f398b479dc273 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -503,7 +503,14 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val join = right.optionalMap(left)(Join(_, _, Inner, None)) withJoinRelations(join, relation) } - ctx.lateralView.asScala.foldLeft(from)(withGenerate) + if (ctx.pivotClause() != null) { + if (!ctx.lateralView.isEmpty) { + throw new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx) + } + withPivot(ctx.pivotClause, from) + } else { + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } } /** @@ -614,6 +621,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging plan } + /** + * Add a [[Pivot]] to a logical plan. + */ + private def withPivot( + ctx: PivotClauseContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val aggregates = Option(ctx.aggregates).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + val pivotColumn = UnresolvedAttribute.quoted(ctx.pivotColumn.getText) + val pivotValues = ctx.pivotValues.asScala.map(typedVisit[Expression]).map(Literal.apply) + Pivot(None, pivotColumn, pivotValues, aggregates, query) + } + /** * Add a [[Generate]] (Lateral View) to a logical plan. */ @@ -1185,6 +1206,34 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging new StringLocate(expression(ctx.substr), expression(ctx.str)) } + /** + * Create a Extract expression. + */ + override def visitExtract(ctx: ExtractContext): Expression = withOrigin(ctx) { + ctx.field.getText.toUpperCase(Locale.ROOT) match { + case "YEAR" => + Year(expression(ctx.source)) + case "QUARTER" => + Quarter(expression(ctx.source)) + case "MONTH" => + Month(expression(ctx.source)) + case "WEEK" => + WeekOfYear(expression(ctx.source)) + case "DAY" => + DayOfMonth(expression(ctx.source)) + case "DAYOFWEEK" => + DayOfWeek(expression(ctx.source)) + case "HOUR" => + Hour(expression(ctx.source)) + case "MINUTE" => + Minute(expression(ctx.source)) + case "SECOND" => + Second(expression(ctx.source)) + case other => + throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) + } + } + /** * Create a (windowed) Function expression. */ @@ -1458,7 +1507,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case "TIMESTAMP" => Literal(Timestamp.valueOf(value)) case "X" => - val padding = if (value.length % 2 == 1) "0" else "" + val padding = if (value.length % 2 != 0) "0" else "" Literal(DatatypeConverter.parseHexBinary(padding + value)) case other => throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 626f905707191..84be677e438a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.planning import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -215,7 +216,7 @@ object PhysicalAggregation { case agg: AggregateExpression if !equivalentAggregateExpressions.addExpr(agg) => agg case udf: PythonUDF - if PythonUDF.isGroupAggPandasUDF(udf) && + if PythonUDF.isGroupedAggPandasUDF(udf) && !equivalentAggregateExpressions.addExpr(udf) => udf } } @@ -245,7 +246,7 @@ object PhysicalAggregation { equivalentAggregateExpressions.getEquivalentExprs(ae).headOption .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute // Similar to AggregateExpression - case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) => + case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) => equivalentAggregateExpressions.getEquivalentExprs(ue).headOption .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute case expression => @@ -268,3 +269,40 @@ object PhysicalAggregation { case _ => None } } + +/** + * An extractor used when planning physical execution of a window. This extractor outputs + * the window function type of the logical window. + * + * The input logical window must contain same type of window functions, which is ensured by + * the rule ExtractWindowExpressions in the analyzer. + */ +object PhysicalWindow { + // windowFunctionType, windowExpression, partitionSpec, orderSpec, child + private type ReturnType = + (WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan) + + def unapply(a: Any): Option[ReturnType] = a match { + case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child) => + + // The window expression should not be empty here, otherwise it's a bug. + if (windowExpressions.isEmpty) { + throw new AnalysisException(s"Window expression is empty in $expr") + } + + val windowFunctionType = windowExpressions.map(WindowFunctionType.functionType) + .reduceLeft { (t1: WindowFunctionType, t2: WindowFunctionType) => + if (t1 != t2) { + // We shouldn't have different window function type here, otherwise it's a bug. + throw new AnalysisException( + s"Found different window function type in $windowExpressions") + } else { + t1 + } + } + + Some((windowFunctionType, windowExpressions, partitionSpec, orderSpec, child)) + + case _ => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 64cb8c726772f..4b4722b0b2117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -27,8 +27,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT /** * The active config object within the current scope. - * Note that if you want to refer config values during execution, you have to capture them - * in Driver and use the captured values in Executors. * See [[SQLConf.get]] for more information. */ def conf: SQLConf = SQLConf.get @@ -119,6 +117,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT case Some(value) => Some(recursiveTransform(value)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs + case stream: Stream[_] => stream.map(recursiveTransform).force case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other case null => null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index b05508db786ad..8c4828a4cef23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -43,6 +44,12 @@ object LocalRelation { } } +/** + * Logical plan node for scanning data from a local collection. + * + * @param data The local collection holding the data. It doesn't need to be sent to executors + * and then doesn't need to be serializable. + */ case class LocalRelation( output: Seq[Attribute], data: Seq[InternalRow] = Nil, @@ -71,7 +78,7 @@ case class LocalRelation( } override def computeStats(): Statistics = - Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) + Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * data.length) def toSQL(inlineTableName: String): String = { require(data.nonEmpty) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 42034403d6d03..c486ad700f362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -78,7 +78,7 @@ abstract class LogicalPlan schema.map { field => resolve(field.name :: Nil, resolver).map { case a: AttributeReference => a - case other => sys.error(s"can not handle nested schema yet... plan $this") + case _ => sys.error(s"can not handle nested schema yet... plan $this") }.getOrElse { throw new AnalysisException( s"Unable to resolve ${field.name} given [${output.map(_.name).mkString(", ")}]") @@ -86,6 +86,10 @@ abstract class LogicalPlan } } + private[this] lazy val childAttributes = AttributeSeq(children.flatMap(_.output)) + + private[this] lazy val outputAttributes = AttributeSeq(output) + /** * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. The attribute is expressed as @@ -94,7 +98,7 @@ abstract class LogicalPlan def resolveChildren( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, children.flatMap(_.output), resolver) + childAttributes.resolve(nameParts, resolver) /** * Optionally resolves the given strings to a [[NamedExpression]] based on the output of this @@ -104,7 +108,7 @@ abstract class LogicalPlan def resolve( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, output, resolver) + outputAttributes.resolve(nameParts, resolver) /** * Given an attribute name, split it to name parts by dot, but @@ -114,105 +118,7 @@ abstract class LogicalPlan def resolveQuoted( name: String, resolver: Resolver): Option[NamedExpression] = { - resolve(UnresolvedAttribute.parseAttributeName(name), output, resolver) - } - - /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * This assumes `name` has multiple parts, where the 1st part is a qualifier - * (i.e. table name, alias, or subquery alias). - * See the comment above `candidates` variable in resolve() for semantics the returned data. - */ - private def resolveAsTableColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - assert(nameParts.length > 1) - if (attribute.qualifier.exists(resolver(_, nameParts.head))) { - // At least one qualifier matches. See if remaining parts match. - val remainingParts = nameParts.tail - resolveAsColumn(remainingParts, resolver, attribute) - } else { - None - } - } - - /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * Different from resolveAsTableColumn, this assumes `name` does NOT start with a qualifier. - * See the comment above `candidates` variable in resolve() for semantics the returned data. - */ - private def resolveAsColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - if (resolver(attribute.name, nameParts.head)) { - Option((attribute.withName(nameParts.head), nameParts.tail.toList)) - } else { - None - } - } - - /** Performs attribute resolution given a name and a sequence of possible attributes. */ - protected def resolve( - nameParts: Seq[String], - input: Seq[Attribute], - resolver: Resolver): Option[NamedExpression] = { - - // A sequence of possible candidate matches. - // Each candidate is a tuple. The first element is a resolved attribute, followed by a list - // of parts that are to be resolved. - // For example, consider an example where "a" is the table name, "b" is the column name, - // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", - // and the second element will be List("c"). - var candidates: Seq[(Attribute, List[String])] = { - // If the name has 2 or more parts, try to resolve it as `table.column` first. - if (nameParts.length > 1) { - input.flatMap { option => - resolveAsTableColumn(nameParts, resolver, option) - } - } else { - Seq.empty - } - } - - // If none of attributes match `table.column` pattern, we try to resolve it as a column. - if (candidates.isEmpty) { - candidates = input.flatMap { candidate => - resolveAsColumn(nameParts, resolver, candidate) - } - } - - def name = UnresolvedAttribute(nameParts).name - - candidates.distinct match { - // One match, no nested fields, use it. - case Seq((a, Nil)) => Some(a) - - // One match, but we also need to extract the requested nested field. - case Seq((a, nestedFields)) => - // The foldLeft adds ExtractValues for every remaining parts of the identifier, - // and aliased it with the last part of the name. - // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final - // expression as "c". - val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => - ExtractValue(expr, Literal(fieldName), resolver)) - Some(Alias(fieldExprs, nestedFields.last)()) - - // No matches. - case Seq() => - logTrace(s"Could not find $name in ${input.mkString(", ")}") - None - - // More than one match. - case ambiguousReferences => - val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ") - throw new AnalysisException( - s"Reference '$name' is ambiguous, could be: $referenceNames.") - } + outputAttributes.resolve(UnresolvedAttribute.parseAttributeName(name), resolver) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index a29f3d29236c7..cc352c59dff80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -20,29 +20,28 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -trait QueryPlanConstraints { self: LogicalPlan => +trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan => /** - * An [[ExpressionSet]] that contains an additional set of constraints, such as equality - * constraints and `isNotNull` constraints, etc. + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. */ - lazy val allConstraints: ExpressionSet = { + lazy val constraints: ExpressionSet = { if (conf.constraintPropagationEnabled) { - ExpressionSet(validConstraints - .union(inferAdditionalConstraints(validConstraints)) - .union(constructIsNotNullConstraints(validConstraints))) + ExpressionSet( + validConstraints + .union(inferAdditionalConstraints(validConstraints)) + .union(constructIsNotNullConstraints(validConstraints, output)) + .filter { c => + c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic + } + ) } else { ExpressionSet(Set.empty) } } - /** - * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For - * example, if this set contains the expression `a = 2` then that expression is guaranteed to - * evaluate to `true` for all rows produced. - */ - lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter(selfReferenceOnly)) - /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints * based on the given operator's constraint propagation logic. These constraints are then @@ -52,30 +51,42 @@ trait QueryPlanConstraints { self: LogicalPlan => * See [[Canonicalize]] for more details. */ protected def validConstraints: Set[Expression] = Set.empty +} + +trait ConstraintHelper { /** - * Returns an [[ExpressionSet]] that contains an additional set of constraints, such as - * equality constraints and `isNotNull` constraints, etc., and that only contains references - * to this [[LogicalPlan]] node. + * Infers an additional set of constraints from a given set of equality constraints. + * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an + * additional constraint of the form `b = 5`. */ - def getRelevantConstraints(constraints: Set[Expression]): ExpressionSet = { - val allRelevantConstraints = - if (conf.constraintPropagationEnabled) { - constraints - .union(inferAdditionalConstraints(constraints)) - .union(constructIsNotNullConstraints(constraints)) - } else { - constraints - } - ExpressionSet(allRelevantConstraints.filter(selfReferenceOnly)) + def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + var inferredConstraints = Set.empty[Expression] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + val candidateConstraints = constraints - eq + inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) + inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) + case _ => // No inference + } + inferredConstraints -- constraints } + private def replaceConstraints( + constraints: Set[Expression], + source: Expression, + destination: Attribute): Set[Expression] = constraints.map(_ transform { + case e: Expression if e.semanticEquals(source) => destination + }) + /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this * returns a constraint of the form `isNotNull(a)` */ - private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + def constructIsNotNullConstraints( + constraints: Set[Expression], + output: Seq[Attribute]): Set[Expression] = { // First, we propagate constraints from the null intolerant expressions. var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints) @@ -111,32 +122,4 @@ trait QueryPlanConstraints { self: LogicalPlan => case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) case _ => Seq.empty[Attribute] } - - /** - * Infers an additional set of constraints from a given set of equality constraints. - * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an - * additional constraint of the form `b = 5`. - */ - private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - var inferredConstraints = Set.empty[Expression] - constraints.foreach { - case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = constraints - eq - inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) - inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) - case _ => // No inference - } - inferredConstraints -- constraints - } - - private def replaceConstraints( - constraints: Set[Expression], - source: Expression, - destination: Attribute): Set[Expression] = constraints.map(_ transform { - case e: Expression if e.semanticEquals(source) => destination - }) - - private def selfReferenceOnly(e: Expression): Boolean = { - e.references.nonEmpty && e.references.subsetOf(outputSet) && e.deterministic - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 10df504795430..3bf32ef7884e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -686,17 +686,34 @@ case class GroupingSets( override lazy val resolved: Boolean = false } +/** + * A constructor for creating a pivot, which will later be converted to a [[Project]] + * or an [[Aggregate]] during the query analysis. + * + * @param groupByExprsOpt A sequence of group by expressions. This field should be None if coming + * from SQL, in which group by expressions are not explicitly specified. + * @param pivotColumn The pivot column. + * @param pivotValues A sequence of values for the pivot column. + * @param aggregates The aggregation expressions, each with or without an alias. + * @param child Child operator + */ case class Pivot( - groupByExprs: Seq[NamedExpression], + groupByExprsOpt: Option[Seq[NamedExpression]], pivotColumn: Expression, pivotValues: Seq[Literal], aggregates: Seq[Expression], child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { - case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) - case _ => pivotValues.flatMap{ value => - aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + override lazy val resolved = false // Pivot will be replaced after being resolved. + override def output: Seq[Attribute] = { + val pivotAgg = aggregates match { + case agg :: Nil => + pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) + case _ => + pivotValues.flatMap { value => + aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + } } + groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 0f147f0ffb135..211a2a0717371 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.BigDecimal.RoundingMode @@ -25,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.{DecimalType, _} - object EstimationUtils { /** Check if each plan has rowCount in its statistics. */ @@ -73,13 +71,12 @@ object EstimationUtils { AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _))) } - def getOutputSize( + def getSizePerRow( attributes: Seq[Attribute], - outputRowCount: BigInt, attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // We assign a generic overhead for a Row object, the actual overhead is different for different // Row format. - val sizePerRow = 8 + attributes.map { attr => + 8 + attributes.map { attr => if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) { attr.dataType match { case StringType => @@ -92,10 +89,15 @@ object EstimationUtils { attr.dataType.defaultSize } }.sum + } + def getOutputSize( + attributes: Seq[Attribute], + outputRowCount: BigInt, + attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // Output size can't be zero, or sizeInBytes of BinaryNode will also be zero // (simple computation of statistics returns product of children). - if (outputRowCount > 0) outputRowCount * sizePerRow else 1 + if (outputRowCount > 0) outputRowCount * getSizePerRow(attributes, attrStats) else 1 } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 0538c9d88584b..5a3eeefaedb18 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -395,6 +395,10 @@ case class FilterEstimation(plan: Filter) extends Logging { // use [min, max] to filter the original hSet dataType match { case _: NumericType | BooleanType | DateType | TimestampType => + if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) { + return Some(0.0) + } + val statsInterval = ValueInterval(colStat.min, colStat.max, dataType).asInstanceOf[NumericValueInterval] val validQuerySet = hSet.filter { v => @@ -418,6 +422,10 @@ case class FilterEstimation(plan: Filter) extends Logging { // We assume the whole set since there is no min/max information for String/Binary type case StringType | BinaryType => + if (ndv.toDouble == 0) { + return Some(0.0) + } + newNdv = ndv.min(BigInt(hSet.size)) if (update) { val newStats = colStat.copy(distinctCount = Some(newNdv), nullCount = Some(0)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 85f67c7d66075..ee43f9126386b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -33,8 +33,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { private def visitUnaryNode(p: UnaryNode): Statistics = { // There should be some overhead in Row object, the size should not be zero when there is // no columns, this help to prevent divide-by-zero error. - val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8 - val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8 + val childRowSize = EstimationUtils.getSizePerRow(p.child.output) + val outputRowSize = EstimationUtils.getSizePerRow(p.output) // Assume there will be the same number of rows as child has. var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize if (sizeInBytes == 0) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 4d9a9925fe3ff..cc1a5e835d9cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -99,16 +99,19 @@ case class ClusteredDistribution( * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the * number of partitions, this distribution strictly requires which partition the tuple should be in. */ -case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution { +case class HashClusteredDistribution( + expressions: Seq[Expression], + requiredNumPartitions: Option[Int] = None) extends Distribution { require( expressions != Nil, - "The expressions for hash of a HashPartitionedDistribution should not be Nil. " + + "The expressions for hash of a HashClusteredDistribution should not be Nil. " + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") - override def requiredNumPartitions: Option[Int] = None - override def createPartitioning(numPartitions: Int): Partitioning = { + assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, + s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + + s"the actual number of partitions is $numPartitions.") HashPartitioning(expressions, numPartitions) } } @@ -163,11 +166,22 @@ trait Partitioning { * i.e. the current dataset does not need to be re-partitioned for the `required` * Distribution (it is possible that tuples within a partition need to be reorganized). * + * A [[Partitioning]] can never satisfy a [[Distribution]] if its `numPartitions` does't match + * [[Distribution.requiredNumPartitions]]. + */ + final def satisfies(required: Distribution): Boolean = { + required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required) + } + + /** + * The actual method that defines whether this [[Partitioning]] can satisfy the given + * [[Distribution]], after the `numPartitions` check. + * * By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], and [[AllTuples]] if - * the [[Partitioning]] only have one partition. Implementations can overwrite this method with - * special logic. + * the [[Partitioning]] only have one partition. Implementations can also overwrite this method + * with special logic. */ - def satisfies(required: Distribution): Boolean = required match { + protected def satisfies0(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case AllTuples => numPartitions == 1 case _ => false @@ -186,9 +200,8 @@ case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning case object SinglePartition extends Partitioning { val numPartitions = 1 - override def satisfies(required: Distribution): Boolean = required match { + override def satisfies0(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false - case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1 case _ => true } } @@ -205,16 +218,15 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: Distribution): Boolean = { + super.satisfies0(required) || { required match { case h: HashClusteredDistribution => expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { case (l, r) => l.semanticEquals(r) } - case ClusteredDistribution(requiredClustering, requiredNumPartitions) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && - (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case ClusteredDistribution(requiredClustering, _) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } } @@ -246,15 +258,14 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: Distribution): Boolean = { + super.satisfies0(required) || { required match { case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering, requiredNumPartitions) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && - (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case ClusteredDistribution(requiredClustering, _) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } } @@ -295,7 +306,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) * Returns true if any `partitioning` of this collection satisfies the given * [[Distribution]]. */ - override def satisfies(required: Distribution): Boolean = + override def satisfies0(required: Distribution): Boolean = partitionings.exists(_.satisfies(required)) override def toString: String = { @@ -310,7 +321,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { override val numPartitions: Int = 1 - override def satisfies(required: Distribution): Boolean = required match { + override def satisfies0(required: Distribution): Boolean = required match { case BroadcastDistribution(m) if m == mode => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 9c7d47f99ee10..becfa8d982213 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -199,44 +199,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { var changed = false val remainingNewChildren = newChildren.toBuffer val remainingOldChildren = children.toBuffer + def mapTreeNode(node: TreeNode[_]): TreeNode[_] = { + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + } + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) + case nonChild: AnyRef => nonChild + case null => null + } val newArgs = mapProductIterator { case s: StructType => s // Don't convert struct types to some other type of Seq[StructField] // Handle Seq[TreeNode] in TreeNode parameters. - case s: Seq[_] => s.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } - case nonChild: AnyRef => nonChild - case null => null - } - case m: Map[_, _] => m.mapValues { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } - case nonChild: AnyRef => nonChild - case null => null - }.view.force // `mapValues` is lazy and we need to force it to materialize - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } + case s: Stream[_] => + // Stream is lazy so we need to force materialization + s.map(mapChild).force + case s: Seq[_] => + s.map(mapChild) + case m: Map[_, _] => + // `mapValues` is lazy and we need to force it to materialize + m.mapValues(mapChild).view.force + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) case nonChild: AnyRef => nonChild case null => null } @@ -301,6 +290,37 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def mapChildren(f: BaseType => BaseType): BaseType = { if (children.nonEmpty) { var changed = false + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = f(arg.asInstanceOf[BaseType]) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => + val newChild1 = if (containsChild(arg1)) { + f(arg1.asInstanceOf[BaseType]) + } else { + arg1.asInstanceOf[BaseType] + } + + val newChild2 = if (containsChild(arg2)) { + f(arg2.asInstanceOf[BaseType]) + } else { + arg2.asInstanceOf[BaseType] + } + + if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { + changed = true + (newChild1, newChild2) + } else { + tuple + } + case other => other + } + val newArgs = mapProductIterator { case arg: TreeNode[_] if containsChild(arg) => val newChild = f(arg.asInstanceOf[BaseType]) @@ -330,36 +350,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case other => other }.view.force // `mapValues` is lazy and we need to force it to materialize case d: DataType => d // Avoid unpacking Structs - case args: Traversable[_] => args.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = f(arg.asInstanceOf[BaseType]) - if (!(newChild fastEquals arg)) { - changed = true - newChild - } else { - arg - } - case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => - val newChild1 = if (containsChild(arg1)) { - f(arg1.asInstanceOf[BaseType]) - } else { - arg1.asInstanceOf[BaseType] - } - - val newChild2 = if (containsChild(arg2)) { - f(arg2.asInstanceOf[BaseType]) - } else { - arg2.asInstanceOf[BaseType] - } - - if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { - changed = true - (newChild1, newChild2) - } else { - tuple - } - case other => other - } + case args: Stream[_] => args.map(mapChild).force // Force materialization on stream + case args: Traversable[_] => args.map(mapChild) case nonChild: AnyRef => nonChild case null => null } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index fa69b8af62c85..80f15053005ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -296,10 +296,28 @@ object DateTimeUtils { * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` */ def stringToTimestamp(s: UTF8String): Option[SQLTimestamp] = { - stringToTimestamp(s, defaultTimeZone()) + stringToTimestamp(s, defaultTimeZone(), rejectTzInString = false) } def stringToTimestamp(s: UTF8String, timeZone: TimeZone): Option[SQLTimestamp] = { + stringToTimestamp(s, timeZone, rejectTzInString = false) + } + + /** + * Converts a timestamp string to microseconds from the unix epoch, w.r.t. the given timezone. + * Returns None if the input string is not a valid timestamp format. + * + * @param s the input timestamp string. + * @param timeZone the timezone of the timestamp string, will be ignored if the timestamp string + * already contains timezone information and `forceTimezone` is false. + * @param rejectTzInString if true, rejects timezone in the input string, i.e., if the + * timestamp string contains timezone, like `2000-10-10 00:00:00+00:00`, + * return None. + */ + def stringToTimestamp( + s: UTF8String, + timeZone: TimeZone, + rejectTzInString: Boolean): Option[SQLTimestamp] = { if (s == null) { return None } @@ -417,6 +435,8 @@ object DateTimeUtils { return None } + if (tz.isDefined && rejectTzInString) return None + val c = if (tz.isEmpty) { Calendar.getInstance(timeZone) } else { @@ -865,29 +885,19 @@ object DateTimeUtils { /** * Returns number of months between time1 and time2. time1 and time2 are expressed in - * microseconds since 1.1.1970. - * - * If time1 and time2 having the same day of month, or both are the last day of month, - * it returns an integer (time under a day will be ignored). - * - * Otherwise, the difference is calculated based on 31 days per month, and rounding to - * 8 digits. - */ - def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp): Double = { - monthsBetween(time1, time2, defaultTimeZone()) - } - - /** - * Returns number of months between time1 and time2. time1 and time2 are expressed in - * microseconds since 1.1.1970. + * microseconds since 1.1.1970. If time1 is later than time2, the result is positive. * - * If time1 and time2 having the same day of month, or both are the last day of month, - * it returns an integer (time under a day will be ignored). + * If time1 and time2 are on the same day of month, or both are the last day of month, + * returns, time of day will be ignored. * - * Otherwise, the difference is calculated based on 31 days per month, and rounding to - * 8 digits. + * Otherwise, the difference is calculated based on 31 days per month. + * The result is rounded to 8 decimal places if `roundOff` is set to true. */ - def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp, timeZone: TimeZone): Double = { + def monthsBetween( + time1: SQLTimestamp, + time2: SQLTimestamp, + roundOff: Boolean, + timeZone: TimeZone): Double = { val millis1 = time1 / 1000L val millis2 = time2 / 1000L val date1 = millisToDays(millis1, timeZone) @@ -898,16 +908,25 @@ object DateTimeUtils { val months1 = year1 * 12 + monthInYear1 val months2 = year2 * 12 + monthInYear2 + val monthDiff = (months1 - months2).toDouble + if (dayInMonth1 == dayInMonth2 || ((daysToMonthEnd1 == 0) && (daysToMonthEnd2 == 0))) { - return (months1 - months2).toDouble + return monthDiff + } + // using milliseconds can cause precision loss with more than 8 digits + // we follow Hive's implementation which uses seconds + val secondsInDay1 = (millis1 - daysToMillis(date1, timeZone)) / 1000L + val secondsInDay2 = (millis2 - daysToMillis(date2, timeZone)) / 1000L + val secondsDiff = (dayInMonth1 - dayInMonth2) * SECONDS_PER_DAY + secondsInDay1 - secondsInDay2 + // 2678400D is the number of seconds in 31 days + // every month is considered to be 31 days long in this function + val diff = monthDiff + secondsDiff / 2678400D + if (roundOff) { + // rounding to 8 digits + math.round(diff * 1e8) / 1e8 + } else { + diff } - // milliseconds is enough for 8 digits precision on the right side - val timeInDay1 = millis1 - daysToMillis(date1, timeZone) - val timeInDay2 = millis2 - daysToMillis(date2, timeZone) - val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY - val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0 - // rounding to 8 digits - math.round(diff * 1e8) / 1e8 } // Thursday = 0 since 1970/Jan/01 => Thursday diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 9e39ed9c3a778..83ad08d8e1758 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -122,7 +122,7 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { return false } - case _ => if (o1 != o2) { + case _ => if (!o1.equals(o2)) { return false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index b013add9c9778..3190e511e2cb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -40,12 +40,14 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats * See the G-K article for more details. * @param count the count of all the elements *inserted in the sampled buffer* * (excluding the head buffer) + * @param compressed whether the statistics have been compressed */ class QuantileSummaries( val compressThreshold: Int, val relativeError: Double, val sampled: Array[Stats] = Array.empty, - val count: Long = 0L) extends Serializable { + val count: Long = 0L, + var compressed: Boolean = false) extends Serializable { // a buffer of latest samples seen so far private val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty @@ -60,6 +62,7 @@ class QuantileSummaries( */ def insert(x: Double): QuantileSummaries = { headSampled += x + compressed = false if (headSampled.size >= defaultHeadSize) { val result = this.withHeadBufferInserted if (result.sampled.length >= compressThreshold) { @@ -135,11 +138,11 @@ class QuantileSummaries( assert(inserted.count == count + headSampled.size) val compressed = compressImmut(inserted.sampled, mergeThreshold = 2 * relativeError * inserted.count) - new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count) + new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count, true) } private def shallowCopy: QuantileSummaries = { - new QuantileSummaries(compressThreshold, relativeError, sampled, count) + new QuantileSummaries(compressThreshold, relativeError, sampled, count, compressed) } /** @@ -163,7 +166,7 @@ class QuantileSummaries( val res = (sampled ++ other.sampled).sortBy(_.value) val comp = compressImmut(res, mergeThreshold = 2 * relativeError * count) new QuantileSummaries( - other.compressThreshold, other.relativeError, comp, other.count + count) + other.compressThreshold, other.relativeError, comp, other.count + count, true) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala new file mode 100644 index 0000000000000..19f67236c8979 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util.{Map => JMap} + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader} + +/** + * A readonly SQLConf that will be created by tasks running at the executor side. It reads the + * configs from the local properties which are propagated from driver to executors. + */ +class ReadOnlySQLConf(context: TaskContext) extends SQLConf { + + @transient override val settings: JMap[String, String] = { + context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]] + } + + @transient override protected val reader: ConfigReader = { + new ConfigReader(new TaskContextConfigProvider(context)) + } + + override protected def setConfWithCheck(key: String, value: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(key: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(entry: ConfigEntry[_]): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clear(): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clone(): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } + + override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } +} + +class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider { + override def get(key: String): Option[String] = Option(context.getLocalProperty(key)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3729bd5293eca..07d33fa7d52ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,11 +27,12 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator import org.apache.spark.util.Utils @@ -95,7 +96,9 @@ object SQLConf { /** * Returns the active config object within the current scope. If there is an active SparkSession, - * the proper SQLConf associated with the thread's session is used. + * the proper SQLConf associated with the thread's active session is used. If it's called from + * tasks in the executor side, a SQLConf will be created from job local properties, which are set + * and propagated from the driver side. * * The way this works is a little bit convoluted, due to the fact that config was added initially * only for physical plans (and as a result not in sql/catalyst module). @@ -107,7 +110,22 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = confGetter.get()() + def get: SQLConf = { + if (TaskContext.get != null) { + new ReadOnlySQLConf(TaskContext.get()) + } else { + if (Utils.isTesting && SparkContext.getActive.isDefined) { + // DAGScheduler event loop thread does not have an active SparkSession, the `confGetter` + // will return `fallbackConf` which is unexpected. Here we prevent it from happening. + val schedulerEventLoopThread = + SparkContext.getActive.get.dagScheduler.eventProcessLoop.eventThread + if (schedulerEventLoopThread.getId == Thread.currentThread().getId) { + throw new RuntimeException("Cannot get SQLConf inside scheduler event loop thread.") + } + } + confGetter.get()() + } + } val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -345,7 +363,7 @@ object SQLConf { "snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) - .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) + .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo", "lz4", "brotli", "zstd")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") @@ -360,6 +378,35 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.timestamp") + .doc("If true, enables Parquet filter push-down optimization for Timestamp. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is " + + "enabled and Timestamp stored as TIMESTAMP_MICROS or TIMESTAMP_MILLIS type.") + .internal() + .booleanConf + .createWithDefault(true) + + val PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.string.startsWith") + .doc("If true, enables Parquet filter push-down optimization for string startsWith function. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .booleanConf + .createWithDefault(true) + + val PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD = + buildConf("spark.sql.parquet.pushdown.inFilterThreshold") + .doc("The maximum number of values to filter push-down optimization for IN predicate. " + + "Large threshold won't necessarily provide much better performance. " + + "The experiment argued that 300 is the limit threshold. " + + "By setting this value to 0 this feature can be disabled. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .intConf + .checkValue(threshold => threshold >= 0, "The threshold must not be negative.") + .createWithDefault(10) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + "versions, when converting Parquet schema to Spark SQL schema and vice versa.") @@ -377,7 +424,7 @@ object SQLConf { .doc("The output committer class used by Parquet. The specified class needs to be a " + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + "of org.apache.parquet.hadoop.ParquetOutputCommitter. If it is not, then metadata summaries" + - "will never be created, irrespective of the value of parquet.enable.summary-metadata") + "will never be created, irrespective of the value of parquet.summary.metadata.level") .internal() .stringConf .createWithDefault("org.apache.parquet.hadoop.ParquetOutputCommitter") @@ -686,6 +733,17 @@ object SQLConf { .intConf .createWithDefault(100) + val CODEGEN_FACTORY_MODE = buildConf("spark.sql.codegen.factoryMode") + .doc("This config determines the fallback behavior of several codegen generators " + + "during tests. `FALLBACK` means trying codegen first and then fallbacking to " + + "interpreted if any compile error happens. Disabling fallback if `CODEGEN_ONLY`. " + + "`NO_CODEGEN` skips codegen and goes interpreted path always. Note that " + + "this config works only for tests.") + .internal() + .stringConf + .checkValues(CodegenObjectFactoryMode.values.map(_.toString)) + .createWithDefault(CodegenObjectFactoryMode.FALLBACK.toString) + val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback") .internal() .doc("When true, (whole stage) codegen could be temporary disabled for the part of query that" + @@ -838,6 +896,21 @@ object SQLConf { .stringConf .createWithDefault("org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol") + val STREAMING_MULTIPLE_WATERMARK_POLICY = + buildConf("spark.sql.streaming.multipleWatermarkPolicy") + .doc("Policy to calculate the global watermark value when there are multiple watermark " + + "operators in a streaming query. The default value is 'min' which chooses " + + "the minimum watermark reported across multiple operators. Other alternative value is" + + "'max' which chooses the maximum across multiple operators." + + "Note: This configuration cannot be changed between query restarts from the same " + + "checkpoint location.") + .stringConf + .checkValue( + str => Set("min", "max").contains(str.toLowerCase), + "Invalid value for 'spark.sql.streaming.multipleWatermarkPolicy'. " + + "Valid values are 'min' and 'max'") + .createWithDefault("min") // must be same as MultipleWatermarkPolicy.DEFAULT_POLICY_NAME + val OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD = buildConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold") .internal() @@ -919,6 +992,14 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(10000L) + val STREAMING_NO_DATA_MICRO_BATCHES_ENABLED = + buildConf("spark.sql.streaming.noDataMicroBatchesEnabled") + .doc( + "Whether streaming micro-batch engine will execute batches without data " + + "for eager state management for stateful streaming queries.") + .booleanConf + .createWithDefault(true) + val STREAMING_METRICS_ENABLED = buildConf("spark.sql.streaming.metricsEnabled") .doc("Whether Dropwizard/Codahale metrics will be reported for active streaming queries.") @@ -1124,6 +1205,16 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION = + buildConf("spark.sql.execution.pandas.groupedMap.assignColumnsByPosition") + .internal() + .doc("When true, a grouped map Pandas UDF will assign columns from the returned " + + "Pandas DataFrame based on position, regardless of column label type. When false, " + + "columns will be looked up by name if labeled with a string and fallback to use " + + "position if not. This configuration will be deprecated in future releases.") + .booleanConf + .createWithDefault(false) + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") .internal() .doc("When true, the apply function of the rule verifies whether the right node of the" + @@ -1147,8 +1238,17 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SQL_OPTIONS_REDACTION_PATTERN = + buildConf("spark.sql.redaction.options.regex") + .doc("Regex to decide which keys in a Spark SQL command's options map contain sensitive " + + "information. The values of options whose names that match this regex will be redacted " + + "in the explain output. This redaction is applied on top of the global redaction " + + s"configuration defined by ${SECRET_REDACTION_PATTERN.key}.") + .regexConf + .createWithDefault("(?i)url".r) + val SQL_STRING_REDACTION_PATTERN = - ConfigBuilder("spark.sql.redaction.string.regex") + buildConf("spark.sql.redaction.string.regex") .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + "information. When this regex matches a string part, that string part is replaced by a " + "dummy value. This is currently used to redact the output of SQL explain commands. " + @@ -1208,6 +1308,13 @@ object SQLConf { .stringConf .createWithDefault("") + val REJECT_TIMEZONE_IN_STRING = buildConf("spark.sql.function.rejectTimezoneInString") + .internal() + .doc("If true, `to_utc_timestamp` and `from_utc_timestamp` return null if the input string " + + "contains a timezone part, e.g. `2000-10-10 00:00:00+00:00`.") + .booleanConf + .createWithDefault(true) + object PartitionOverwriteMode extends Enumeration { val STATIC, DYNAMIC = Value } @@ -1238,6 +1345,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val TOP_K_SORT_FALLBACK_THRESHOLD = + buildConf("spark.sql.execution.topKSortFallbackThreshold") + .internal() + .doc("In SQL queries with a SORT followed by a LIMIT like " + + "'SELECT x FROM t ORDER BY y LIMIT m', if m is under this threshold, do a top-K sort" + + " in memory, otherwise do a global sort which spills to disk if necessary.") + .intConf + .createWithDefault(Int.MaxValue) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1245,6 +1361,42 @@ object SQLConf { object Replaced { val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" } + + val CSV_PARSER_COLUMN_PRUNING = buildConf("spark.sql.csv.parser.columnPruning.enabled") + .internal() + .doc("If it is set to true, column names of the requested schema are passed to CSV parser. " + + "Other column values can be ignored during parsing even if they are malformed.") + .booleanConf + .createWithDefault(true) + + val LEGACY_SIZE_OF_NULL = buildConf("spark.sql.legacy.sizeOfNull") + .doc("If it is set to true, size of null returns -1. This behavior was inherited from Hive. " + + "The size function returns null for null input if the flag is disabled.") + .booleanConf + .createWithDefault(true) + + val REPL_EAGER_EVAL_ENABLED = buildConf("spark.sql.repl.eagerEval.enabled") + .doc("Enables eager evaluation or not. When true, the top K rows of Dataset will be " + + "displayed if and only if the REPL supports the eager evaluation. Currently, the " + + "eager evaluation is only supported in PySpark. For the notebooks like Jupyter, " + + "the HTML table (generated by _repr_html_) will be returned. For plain Python REPL, " + + "the returned outputs are formatted like dataframe.show().") + .booleanConf + .createWithDefault(false) + + val REPL_EAGER_EVAL_MAX_NUM_ROWS = buildConf("spark.sql.repl.eagerEval.maxNumRows") + .doc("The max number of rows that are returned by eager evaluation. This only takes " + + "effect when spark.sql.repl.eagerEval.enabled is set to true. The valid range of this " + + "config is from 0 to (Int.MaxValue - 1), so the invalid config like negative and " + + "greater than (Int.MaxValue - 1) will be normalized to 0 and (Int.MaxValue - 1).") + .intConf + .createWithDefault(20) + + val REPL_EAGER_EVAL_TRUNCATE = buildConf("spark.sql.repl.eagerEval.truncate") + .doc("The max number of characters for each cell that is returned by eager evaluation. " + + "This only takes effect when spark.sql.repl.eagerEval.enabled is set to true.") + .intConf + .createWithDefault(20) } /** @@ -1259,17 +1411,11 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ - if (Utils.isTesting && SparkEnv.get != null) { - // assert that we're only accessing it on the driver. - assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, - "SQLConf should only be created and accessed on the driver.") - } - /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) - @transient private val reader = new ConfigReader(settings) + @transient protected val reader = new ConfigReader(settings) /** ************************ Spark SQL Params/Hints ******************* */ @@ -1306,6 +1452,9 @@ class SQLConf extends Serializable with Logging { def streamingNoDataProgressEventInterval: Long = getConf(STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL) + def streamingNoDataMicroBatchesEnabled: Boolean = + getConf(STREAMING_NO_DATA_MICRO_BATCHES_ENABLED) + def streamingMetricsEnabled: Boolean = getConf(STREAMING_METRICS_ENABLED) def streamingProgressRetention: Int = getConf(STREAMING_PROGRESS_RETENTION) @@ -1354,6 +1503,14 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED) + def parquetFilterPushDownTimestamp: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED) + + def parquetFilterPushDownStringStartWith: Boolean = + getConf(PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED) + + def parquetFilterPushDownInFilterThreshold: Int = + getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) @@ -1392,6 +1549,8 @@ class SQLConf extends Serializable with Logging { def tableRelationCacheSize: Int = getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) + def codegenCacheMaxEntries: Int = getConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES) + def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) @@ -1402,10 +1561,12 @@ class SQLConf extends Serializable with Logging { def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR) - def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader) + def stringRedactionPattern: Option[Regex] = getConf(SQL_STRING_REDACTION_PATTERN) def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) + def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. @@ -1579,6 +1740,9 @@ class SQLConf extends Serializable with Logging { def pandasRespectSessionTimeZone: Boolean = getConf(PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE) + def pandasGroupedMapAssignColumnssByPosition: Boolean = + getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) @@ -1603,6 +1767,16 @@ class SQLConf extends Serializable with Logging { def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) + + def legacySizeOfNull: Boolean = getConf(SQLConf.LEGACY_SIZE_OF_NULL) + + def isReplEagerEvalEnabled: Boolean = getConf(SQLConf.REPL_EAGER_EVAL_ENABLED) + + def replEagerEvalMaxNumRows: Int = getConf(SQLConf.REPL_EAGER_EVAL_MAX_NUM_ROWS) + + def replEagerEvalTruncate: Int = getConf(SQLConf.REPL_EAGER_EVAL_TRUNCATE) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ @@ -1709,6 +1883,17 @@ class SQLConf extends Serializable with Logging { }.toSeq } + /** + * Redacts the given option map according to the description of SQL_OPTIONS_REDACTION_PATTERN. + */ + def redactOptions(options: Map[String, String]): Map[String, String] = { + val regexes = Seq( + getConf(SQL_OPTIONS_REDACTION_PATTERN), + SECRET_REDACTION_PATTERN.readFrom(reader)) + + regexes.foldLeft(options.toSeq) { case (opts, r) => Utils.redact(Some(r), opts) }.toMap + } + /** * Return whether a given key is set in this [[SQLConf]]. */ @@ -1716,7 +1901,7 @@ class SQLConf extends Serializable with Logging { settings.containsKey(key) } - private def setConfWithCheck(key: String, value: String): Unit = { + protected def setConfWithCheck(key: String, value: String): Unit = { settings.put(key, value) } @@ -1748,4 +1933,8 @@ class SQLConf extends Serializable with Logging { } cloned } + + def isModifiable(key: String): Boolean = { + sqlConfEntries.containsKey(key) && !staticConfKeys.contains(key) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index fe0ad39c29025..384b1917a1f79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -66,6 +66,14 @@ object StaticSQLConf { .checkValue(cacheSize => cacheSize >= 0, "The maximum size of the cache must not be negative") .createWithDefault(1000) + val CODEGEN_CACHE_MAX_ENTRIES = buildStaticConf("spark.sql.codegen.cache.maxEntries") + .internal() + .doc("When nonzero, enable caching of generated classes for operators and expressions. " + + "All jobs share the cache that can use up to the specified number for generated classes.") + .intConf + .checkValue(maxEntries => maxEntries >= 0, "The maximum must not be negative") + .createWithDefault(100) + // When enabling the debug, Spark SQL internal table properties are not filtered out; however, // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. val DEBUG_MODE = buildStaticConf("spark.sql.debug") @@ -96,6 +104,14 @@ object StaticSQLConf { .toSequence .createOptional + val STREAMING_QUERY_LISTENERS = buildStaticConf("spark.sql.streaming.streamingQueryListeners") + .doc("List of class names implementing StreamingQueryListener that will be automatically " + + "added to newly created sessions. The classes should have either a no-arg constructor, " + + "or a constructor that expects a SparkConf argument.") + .stringConf + .toSequence + .createOptional + val UI_RETAINED_EXECUTIONS = buildStaticConf("spark.sql.ui.retainedExecutions") .doc("Number of executions to retain in the Spark UI.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 0bef11659fc9e..fd40741cfb5f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.types import java.util.Locale +import scala.util.control.NonFatal + import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ @@ -26,6 +28,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -110,6 +113,14 @@ abstract class DataType extends AbstractDataType { @InterfaceStability.Stable object DataType { + def fromDDL(ddl: String): DataType = { + try { + CatalystSqlParser.parseDataType(ddl) + } catch { + case NonFatal(_) => CatalystSqlParser.parseTableSchema(ddl) + } + } + def fromJson(json: String): DataType = parseDataType(parse(json)) private val nonDecimalNameToType = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index ef3b67c0d48d0..dbf51c398fa47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -161,13 +161,17 @@ object DecimalType extends AbstractDataType { * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true. */ private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = { - // Assumptions: + // Assumption: assert(precision >= scale) - assert(scale >= 0) if (precision <= MAX_PRECISION) { // Adjustment only needed when we exceed max precision DecimalType(precision, scale) + } else if (scale < 0) { + // Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision + // loss since we would cause a loss of digits in the integer part. + // In this case, we are likely to meet an overflow. + DecimalType(MAX_PRECISION, scale) } else { // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION. val intDigits = precision - scale diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index f3702ec92b425..89452ee05cff3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class CatalystTypeConvertersSuite extends SparkFunSuite { @@ -94,4 +95,56 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray) == doubleGenericArray) } + + test("converting a wrong value to the struct type") { + val structType = new StructType().add("f1", IntegerType) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(structType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to struct")) + } + + test("converting a wrong value to the map type") { + val mapType = MapType(StringType, IntegerType, false) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(mapType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to a map type with key " + + "type (string) and value type (int)")) + } + + test("converting a wrong value to the array type") { + val arrayType = ArrayType(IntegerType, true) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(arrayType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to an array of int")) + } + + test("converting a wrong value to the decimal type") { + val decimalType = DecimalType(10, 0) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(decimalType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to decimal(10,0)")) + } + + test("converting a wrong value to the string type") { + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(StringType)(0.1) + } + assert(exception.getMessage.contains("The value (0.1) of the type " + + "(java.lang.Double) cannot be converted to the string type")) + } + + test("SPARK-24571: convert Char to String") { + val chr: Char = 'X' + val converter = CatalystTypeConverters.createToCatalystConverter(StringType) + val expected = UTF8String.fromString("X") + assert(converter(chr) === expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index b47b8adfe5d55..39228102682b9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -41,34 +41,127 @@ class DistributionSuite extends SparkFunSuite { } } - test("HashPartitioning (with nullSafe = true) is the output partitioning") { - // Cases which do not need an exchange between two data properties. + test("UnspecifiedDistribution and AllTuples") { + // except `BroadcastPartitioning`, all other partitioning can satisfy UnspecifiedDistribution checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), + UnknownPartitioning(-1), UnspecifiedDistribution, true) checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + RoundRobinPartitioning(10), + UnspecifiedDistribution, true) checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + SinglePartition, + UnspecifiedDistribution, + true) + + checkSatisfied( + HashPartitioning(Seq('a), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + BroadcastPartitioning(IdentityBroadcastMode), + UnspecifiedDistribution, + false) + + // except `BroadcastPartitioning`, all other partitioning can satisfy AllTuples if they have + // only one partition. + checkSatisfied( + UnknownPartitioning(1), + AllTuples, + true) + + checkSatisfied( + UnknownPartitioning(10), + AllTuples, + false) + + checkSatisfied( + RoundRobinPartitioning(1), + AllTuples, + true) + + checkSatisfied( + RoundRobinPartitioning(10), + AllTuples, + false) + + checkSatisfied( + SinglePartition, + AllTuples, + true) + + checkSatisfied( + HashPartitioning(Seq('a), 1), + AllTuples, true) + checkSatisfied( + HashPartitioning(Seq('a), 10), + AllTuples, + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 1), + AllTuples, + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 10), + AllTuples, + false) + + checkSatisfied( + BroadcastPartitioning(IdentityBroadcastMode), + AllTuples, + false) + } + + test("SinglePartition is the output partitioning") { + // SinglePartition can satisfy all the distributions except `BroadcastDistribution` checkSatisfied( SinglePartition, ClusteredDistribution(Seq('a, 'b, 'c)), true) + checkSatisfied( + SinglePartition, + HashClusteredDistribution(Seq('a, 'b, 'c)), + true) + checkSatisfied( SinglePartition, OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), true) - // Cases which need an exchange between two data properties. + checkSatisfied( + SinglePartition, + BroadcastDistribution(IdentityBroadcastMode), + false) + } + + test("HashPartitioning is the output partitioning") { + // HashPartitioning can satisfy ClusteredDistribution iff its hash expressions are a subset of + // the required clustering expressions. + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + HashPartitioning(Seq('b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), ClusteredDistribution(Seq('b, 'c)), @@ -79,37 +172,43 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('d, 'e)), false) + // HashPartitioning can satisfy HashClusteredDistribution iff its hash expressions are exactly + // same with the required hash clustering expressions. checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), - AllTuples, + HashClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + HashPartitioning(Seq('c, 'b, 'a), 10), + HashClusteredDistribution(Seq('a, 'b, 'c)), false) + checkSatisfied( + HashPartitioning(Seq('a, 'b), 10), + HashClusteredDistribution(Seq('a, 'b, 'c)), + false) + + // HashPartitioning cannot satisfy OrderedDistribution checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), false) checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), + HashPartitioning(Seq('a, 'b, 'c), 1), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), - false) + false) // TODO: this can be relaxed. - // TODO: We should check functional dependencies - /* checkSatisfied( - ClusteredDistribution(Seq('b)), - ClusteredDistribution(Seq('b + 1)), - true) - */ + HashPartitioning(Seq('b, 'c), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + false) } test("RangePartitioning is the output partitioning") { - // Cases which do not need an exchange between two data properties. - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - UnspecifiedDistribution, - true) - + // RangePartitioning can satisfy OrderedDistribution iff its ordering is a prefix + // of the required ordering, or the required ordering is a prefix of its ordering. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), @@ -125,6 +224,27 @@ class DistributionSuite extends SparkFunSuite { OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc, 'd.desc)), true) + // TODO: We can have an optimization to first sort the dataset + // by a.asc and then sort b, and c in a partition. This optimization + // should tradeoff the benefit of a less number of Exchange operators + // and the parallelism. + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('b.asc, 'a.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'd.desc)), + false) + + // RangePartitioning can satisfy ClusteredDistribution iff its ordering expressions are a subset + // of the required clustering expressions. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), ClusteredDistribution(Seq('a, 'b, 'c)), @@ -140,34 +260,47 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('b, 'c, 'a, 'd)), true) - // Cases which need an exchange between two data properties. - // TODO: We can have an optimization to first sort the dataset - // by a.asc and then sort b, and c in a partition. This optimization - // should tradeoff the benefit of a less number of Exchange operators - // and the parallelism. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), + ClusteredDistribution(Seq('a, 'b)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('b.asc, 'a.asc)), + ClusteredDistribution(Seq('c, 'd)), false) + // RangePartitioning cannot satisfy HashClusteredDistribution checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('a, 'b)), + HashClusteredDistribution(Seq('a, 'b, 'c)), false) + } + test("Partitioning.numPartitions must match Distribution.requiredNumPartitions to satisfy it") { checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('c, 'd)), + SinglePartition, + ClusteredDistribution(Seq('a, 'b, 'c), Some(10)), + false) + + checkSatisfied( + SinglePartition, + HashClusteredDistribution(Seq('a, 'b, 'c), Some(10)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c), Some(5)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + HashClusteredDistribution(Seq('a, 'b, 'c), Some(5)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - AllTuples, + ClusteredDistribution(Seq('a, 'b, 'c), Some(5)), false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index cd8579584eada..bbcdf6c1b8481 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -21,6 +21,7 @@ import java.util.TimeZone import org.scalatest.Matchers +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -557,4 +558,21 @@ class AnalysisSuite extends AnalysisTest with Matchers { SubqueryAlias("tbl", testRelation))) assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'")) } + + test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") { + val pythonUdf = PythonUDF("pyUDF", null, + StructType(Seq(StructField("a", LongType))), + Seq.empty, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + true) + val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes + val project = Project(Seq(UnresolvedAttribute("a")), testRelation) + val flatMapGroupsInPandas = FlatMapGroupsInPandas( + Seq(UnresolvedAttribute("a")), pythonUdf, output, project) + val left = SubqueryAlias("temp0", flatMapGroupsInPandas) + val right = SubqueryAlias("temp1", flatMapGroupsInPandas) + val join = Join(left, right, Inner, None) + assertAnalysisSuccess( + Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index c86dc18dfa680..bd87ca6017e99 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -272,6 +272,15 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { } } + test("SPARK-24468: operations on decimals with negative scale") { + val a = AttributeReference("a", DecimalType(3, -10))() + val b = AttributeReference("b", DecimalType(1, -1))() + val c = AttributeReference("c", DecimalType(35, 1))() + checkType(Multiply(a, b), DecimalType(5, -11)) + checkType(Multiply(a, c), DecimalType(38, -9)) + checkType(Multiply(b, c), DecimalType(37, 0)) + } + /** strength reduction for integer/decimal comparisons */ def ruleTest(initial: Expression, transformed: Expression): Unit = { val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala new file mode 100644 index 0000000000000..cea0f2a9cbc97 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import java.net.URI + +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf + +class LookupFunctionsSuite extends PlanTest { + + test("SPARK-23486: the functionExists for the Persistent function check") { + val externalCatalog = new CustomInMemoryCatalog + val conf = new SQLConf() + val catalog = new SessionCatalog(externalCatalog, FunctionRegistry.builtin, conf) + val analyzer = { + catalog.createDatabase( + CatalogDatabase("default", "", new URI("loc"), Map.empty), + ignoreIfExists = false) + new Analyzer(catalog, conf) + } + + def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref)) + val unresolvedPersistentFunc = UnresolvedFunction("func", Seq.empty, false) + val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false) + val plan = Project( + Seq(Alias(unresolvedPersistentFunc, "call1")(), Alias(unresolvedPersistentFunc, "call2")(), + Alias(unresolvedPersistentFunc, "call3")(), Alias(unresolvedRegisteredFunc, "call4")(), + Alias(unresolvedRegisteredFunc, "call5")()), + table("TaBlE")) + analyzer.LookupFunctions.apply(plan) + + assert(externalCatalog.getFunctionExistsCalledTimes == 1) + assert(analyzer.LookupFunctions.normalizeFuncName + (unresolvedPersistentFunc.name).database == Some("default")) + } + + test("SPARK-23486: the functionExists for the Registered function check") { + val externalCatalog = new InMemoryCatalog + val conf = new SQLConf() + val customerFunctionReg = new CustomerFunctionRegistry + val catalog = new SessionCatalog(externalCatalog, customerFunctionReg, conf) + val analyzer = { + catalog.createDatabase( + CatalogDatabase("default", "", new URI("loc"), Map.empty), + ignoreIfExists = false) + new Analyzer(catalog, conf) + } + + def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref)) + val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false) + val plan = Project( + Seq(Alias(unresolvedRegisteredFunc, "call1")(), Alias(unresolvedRegisteredFunc, "call2")()), + table("TaBlE")) + analyzer.LookupFunctions.apply(plan) + + assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2) + assert(analyzer.LookupFunctions.normalizeFuncName + (unresolvedRegisteredFunc.name).database == Some("default")) + } +} + +class CustomerFunctionRegistry extends SimpleFunctionRegistry { + + private var isRegisteredFunctionCalledTimes: Int = 0; + + override def functionExists(funcN: FunctionIdentifier): Boolean = synchronized { + isRegisteredFunctionCalledTimes = isRegisteredFunctionCalledTimes + 1 + true + } + + def getIsRegisteredFunctionCalledTimes: Int = isRegisteredFunctionCalledTimes +} + +class CustomInMemoryCatalog extends InMemoryCatalog { + + private var functionExistsCalledTimes: Int = 0 + + override def functionExists(db: String, funcName: String): Boolean = synchronized { + functionExistsCalledTimes = functionExistsCalledTimes + 1 + true + } + + def getFunctionExistsCalledTimes: Int = functionExistsCalledTimes +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index fd6a3121663ed..8cc5a23779a2a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -54,8 +54,9 @@ class TypeCoercionSuite extends AnalysisTest { // | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType | // | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X | // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // Note: MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable. + // Note: StructType* is castable when all the internal child types are castable according to the table. // Note: ArrayType* is castable when the element type is castable according to the table. + // Note: MapType* is castable when both the key type and the value type are castable according to the table. // scalastyle:on line.size.limit private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { @@ -396,7 +397,7 @@ class TypeCoercionSuite extends AnalysisTest { widenTest( StructType(Seq(StructField("a", IntegerType, nullable = false))), StructType(Seq(StructField("a", DoubleType, nullable = false))), - None) + Some(StructType(Seq(StructField("a", DoubleType, nullable = false))))) widenTest( StructType(Seq(StructField("a", IntegerType, nullable = false))), @@ -429,21 +430,42 @@ class TypeCoercionSuite extends AnalysisTest { Some(StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType)))), isSymmetric = false) } + + widenTest( + ArrayType(IntegerType, containsNull = true), + ArrayType(IntegerType, containsNull = false), + Some(ArrayType(IntegerType, containsNull = true))) + + widenTest( + MapType(IntegerType, StringType, valueContainsNull = true), + MapType(IntegerType, StringType, valueContainsNull = false), + Some(MapType(IntegerType, StringType, valueContainsNull = true))) + + widenTest( + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = false), + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = false), nullable = true), + Some(new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true))) } test("wider common type for decimal and array") { def widenTestWithStringPromotion( t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { - checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected) + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected, isSymmetric) } def widenTestWithoutStringPromotion( t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { - checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected) + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + checkWidenType( + TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected, isSymmetric) } // Decimal @@ -469,12 +491,108 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(ArrayType(IntegerType), containsNull = false), ArrayType(ArrayType(LongType), containsNull = false), Some(ArrayType(ArrayType(LongType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(MapType(IntegerType, FloatType), containsNull = false), + ArrayType(MapType(LongType, DoubleType), containsNull = false), + Some(ArrayType(MapType(LongType, DoubleType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(new StructType().add("num", ShortType), containsNull = false), + ArrayType(new StructType().add("num", LongType), containsNull = false), + Some(ArrayType(new StructType().add("num", LongType), containsNull = false))) + + // MapType + widenTestWithStringPromotion( + MapType(ShortType, TimestampType, valueContainsNull = true), + MapType(DoubleType, StringType, valueContainsNull = false), + Some(MapType(DoubleType, StringType, valueContainsNull = true))) + widenTestWithStringPromotion( + MapType(IntegerType, ArrayType(TimestampType), valueContainsNull = false), + MapType(LongType, ArrayType(StringType), valueContainsNull = true), + Some(MapType(LongType, ArrayType(StringType), valueContainsNull = true))) + widenTestWithStringPromotion( + MapType(IntegerType, MapType(ShortType, TimestampType), valueContainsNull = false), + MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false), + Some(MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(IntegerType, new StructType().add("num", ShortType), valueContainsNull = false), + MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false), + Some(MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false))) + + // StructType + widenTestWithStringPromotion( + new StructType() + .add("num", ShortType, nullable = true).add("ts", StringType, nullable = false), + new StructType() + .add("num", DoubleType, nullable = false).add("ts", TimestampType, nullable = true), + Some(new StructType() + .add("num", DoubleType, nullable = true).add("ts", StringType, nullable = true))) + widenTestWithStringPromotion( + new StructType() + .add("arr", ArrayType(ShortType, containsNull = false), nullable = false), + new StructType() + .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false), + Some(new StructType() + .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false))) + widenTestWithStringPromotion( + new StructType() + .add("map", MapType(ShortType, TimestampType, valueContainsNull = true), nullable = false), + new StructType() + .add("map", MapType(DoubleType, StringType, valueContainsNull = false), nullable = false), + Some(new StructType() + .add("map", MapType(DoubleType, StringType, valueContainsNull = true), nullable = false))) + + widenTestWithStringPromotion( + new StructType().add("num", IntegerType), + new StructType().add("num", LongType).add("str", StringType), + None) + widenTestWithoutStringPromotion( + new StructType().add("num", IntegerType), + new StructType().add("num", LongType).add("str", StringType), + None) + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + None) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + None) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + Some(new StructType().add("a", LongType)), + isSymmetric = false) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + Some(new StructType().add("a", LongType)), + isSymmetric = false) + } // Without string promotion widenTestWithoutStringPromotion(IntegerType, StringType, None) widenTestWithoutStringPromotion(StringType, TimestampType, None) widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None) widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None) + widenTestWithoutStringPromotion( + MapType(LongType, IntegerType), MapType(StringType, IntegerType), None) + widenTestWithoutStringPromotion( + MapType(IntegerType, LongType), MapType(IntegerType, StringType), None) + widenTestWithoutStringPromotion( + MapType(StringType, IntegerType), MapType(TimestampType, IntegerType), None) + widenTestWithoutStringPromotion( + MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), None) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("a", StringType), + None) + widenTestWithoutStringPromotion( + new StructType().add("a", StringType), + new StructType().add("a", IntegerType), + None) // String promotion widenTestWithStringPromotion(IntegerType, StringType, Some(StringType)) @@ -483,6 +601,30 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(LongType), ArrayType(StringType), Some(ArrayType(StringType))) widenTestWithStringPromotion( ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType))) + widenTestWithStringPromotion( + MapType(LongType, IntegerType), + MapType(StringType, IntegerType), + Some(MapType(StringType, IntegerType))) + widenTestWithStringPromotion( + MapType(IntegerType, LongType), + MapType(IntegerType, StringType), + Some(MapType(IntegerType, StringType))) + widenTestWithStringPromotion( + MapType(StringType, IntegerType), + MapType(TimestampType, IntegerType), + Some(MapType(StringType, IntegerType))) + widenTestWithStringPromotion( + MapType(IntegerType, StringType), + MapType(IntegerType, TimestampType), + Some(MapType(IntegerType, StringType))) + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("a", StringType), + Some(new StructType().add("a", StringType))) + widenTestWithStringPromotion( + new StructType().add("a", StringType), + new StructType().add("a", IntegerType), + Some(new StructType().add("a", StringType))) } private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { @@ -506,11 +648,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -518,11 +660,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } @@ -805,7 +947,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) @@ -1257,7 +1399,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 60d1351fda264..cb487c8893541 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -621,6 +621,13 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Append, expectedMsgs = Seq("monotonically_increasing_id")) + assertSupportedForContinuousProcessing( + "TypedFilter", TypedFilter( + null, + null, + null, + null, + new TestStreamingRelationV2(attribute)), OutputMode.Append()) /* ======================================================================================= @@ -771,6 +778,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite { } } + /** Assert that the logical plan is supported for continuous procsssing mode */ + def assertSupportedForContinuousProcessing( + name: String, + plan: LogicalPlan, + outputMode: OutputMode): Unit = { + test(s"continuous processing - $name: supported") { + UnsupportedOperationChecker.checkForContinuous(plan, outputMode) + } + } + /** * Assert that the logical plan is not supported inside a streaming plan. * @@ -840,4 +857,10 @@ class UnsupportedOperationsSuite extends SparkFunSuite { def this(attribute: Attribute) = this(Seq(attribute)) override def isStreaming: Boolean = true } + + case class TestStreamingRelationV2(output: Seq[Attribute]) extends LeafNode { + def this(attribute: Attribute) = this(Seq(attribute)) + override def isStreaming: Boolean = true + override def nodeName: String = "StreamingRelationV2" + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala index 1acbe34d9a075..2fcaeca34db3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -36,7 +36,7 @@ class ExternalCatalogEventSuite extends SparkFunSuite { private def testWithCatalog( name: String)( f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) { - val catalog = newCatalog + val catalog = new ExternalCatalogWithListener(newCatalog) val recorder = mutable.Buffer.empty[ExternalCatalogEvent] catalog.addListener(new ExternalCatalogEventListener { override def onEvent(event: ExternalCatalogEvent): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 6abab0073cca3..50496a0410528 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -1114,11 +1114,13 @@ abstract class SessionCatalogSuite extends AnalysisTest { // And for hive serde table, hive metastore will set some values(e.g.transient_lastDdlTime) // in table's parameters and storage's properties, here we also ignore them. val actualPartsNormalize = actualParts.map(p => - p.copy(parameters = Map.empty, storage = p.storage.copy( + p.copy(parameters = Map.empty, createTime = -1, lastAccessTime = -1, + storage = p.storage.copy( properties = Map.empty, locationUri = None, serde = None))).toSet val expectedPartsNormalize = expectedParts.map(p => - p.copy(parameters = Map.empty, storage = p.storage.copy( + p.copy(parameters = Map.empty, createTime = -1, lastAccessTime = -1, + storage = p.storage.copy( properties = Map.empty, locationUri = None, serde = None))).toSet actualPartsNormalize == expectedPartsNormalize @@ -1215,6 +1217,42 @@ abstract class SessionCatalogSuite extends AnalysisTest { } } + test("isRegisteredFunction") { + withBasicCatalog { catalog => + // Returns false when the function does not register + assert(!catalog.isRegisteredFunction(FunctionIdentifier("temp1"))) + + // Returns true when the function does register + val tempFunc1 = (e: Seq[Expression]) => e.head + catalog.registerFunction(newFunc("iff", None), overrideIfExists = false, + functionBuilder = Some(tempFunc1) ) + assert(catalog.isRegisteredFunction(FunctionIdentifier("iff"))) + + // Returns false when using the createFunction + catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false) + assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum"))) + assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum", Some("db2")))) + } + } + + test("isPersistentFunction") { + withBasicCatalog { catalog => + // Returns false when the function does not register + assert(!catalog.isPersistentFunction(FunctionIdentifier("temp2"))) + + // Returns false when the function does register + val tempFunc2 = (e: Seq[Expression]) => e.head + catalog.registerFunction(newFunc("iff", None), overrideIfExists = false, + functionBuilder = Some(tempFunc2)) + assert(!catalog.isPersistentFunction(FunctionIdentifier("iff"))) + + // Return true when using the createFunction + catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false) + assert(catalog.isPersistentFunction(FunctionIdentifier("sum", Some("db2")))) + assert(!catalog.isPersistentFunction(FunctionIdentifier("db2.sum"))) + } + } + test("drop function") { withBasicCatalog { catalog => assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala new file mode 100644 index 0000000000000..28e6940f3cca3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical.Range + +class CanonicalizeSuite extends SparkFunSuite { + + test("SPARK-24276: IN expression with different order are semantically equal") { + val range = Range(1, 1, 1, 1) + val idAttr = range.output.head + + val in1 = In(idAttr, Seq(Literal(1), Literal(2))) + val in2 = In(idAttr, Seq(Literal(2), Literal(1))) + val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3))) + + assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash()) + assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash()) + + assert(range.where(in1).sameResult(range.where(in2))) + assert(!range.where(in1).sameResult(range.where(in3))) + + val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + CreateArray(Seq(Literal(2), Literal(1))))) + val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))), + CreateArray(Seq(Literal(1), Literal(2))))) + val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + CreateArray(Seq(Literal(3), Literal(1))))) + + assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash()) + assert(arrays1.canonicalized.semanticHash() != arrays3.canonicalized.semanticHash()) + + assert(range.where(arrays1).sameResult(range.where(arrays2))) + assert(!range.where(arrays1).sameResult(range.where(arrays3))) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala new file mode 100644 index 0000000000000..531ca9a87370a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.PlanTestBase +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{IntegerType, LongType} + +class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase { + + test("UnsafeProjection with codegen factory mode") { + val input = Seq(LongType, IntegerType) + .zipWithIndex.map(x => BoundReference(x._2, x._1, true)) + + val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { + val obj = UnsafeProjection.createObject(input) + assert(obj.getClass.getName.contains("GeneratedClass$SpecificUnsafeProjection")) + } + + val noCodegen = CodegenObjectFactoryMode.NO_CODEGEN.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) { + val obj = UnsafeProjection.createObject(input) + assert(obj.isInstanceOf[InterpretedUnsafeProjection]) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 43c5dda2e4a48..85d6a1befed6b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -17,12 +17,21 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} +import java.util.TimeZone + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH +import org.apache.spark.unsafe.types.CalendarInterval class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { - test("Array and Map Size") { + def testSize(sizeOfNull: Any): Unit = { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) @@ -39,8 +48,24 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Size(m1), 0) checkEvaluation(Size(m2), 1) - checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), -1) - checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), -1) + checkEvaluation( + Size(Literal.create(null, MapType(StringType, StringType))), + expected = sizeOfNull) + checkEvaluation( + Size(Literal.create(null, ArrayType(StringType))), + expected = sizeOfNull) + } + + test("Array and Map Size - legacy") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSize(sizeOfNull = -1) + } + } + + test("Array and Map Size") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSize(sizeOfNull = null) + } } test("MapKeys/MapValues") { @@ -56,33 +81,262 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapValues(m2), null) } + test("MapEntries") { + def r(values: Any*): InternalRow = create_row(values: _*) + + // Primitive-type keys/values + val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType)) + val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType)) + val mi2 = Literal.create(null, MapType(IntegerType, IntegerType)) + + checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2))) + checkEvaluation(MapEntries(mi1), Seq.empty) + checkEvaluation(MapEntries(mi2), null) + + // Non-primitive-type keys/values + val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType)) + val ms1 = Literal.create(Map[Int, Int](), MapType(StringType, StringType)) + val ms2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null))) + checkEvaluation(MapEntries(ms1), Seq.empty) + checkEvaluation(MapEntries(ms2), null) + } + + test("Map Concat") { + val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType, + valueContainsNull = false)) + val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType, + valueContainsNull = false)) + val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) + val m3 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m4 = Literal.create(Map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) + val m5 = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) + val m6 = Literal.create(Map("a" -> null, "c" -> 3), MapType(StringType, IntegerType)) + val m7 = Literal.create(Map(List(1, 2) -> 1, List(3, 4) -> 2), + MapType(ArrayType(IntegerType), IntegerType)) + val m8 = Literal.create(Map(List(5, 6) -> 3, List(1, 2) -> 4), + MapType(ArrayType(IntegerType), IntegerType)) + val m9 = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val m10 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val m11 = Literal.create(Map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType, + valueContainsNull = false)) + val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, + valueContainsNull = false)) + val mNull = Literal.create(null, MapType(StringType, StringType)) + + // overlapping maps + checkEvaluation(MapConcat(Seq(m0, m1)), + ( + Array("a", "b", "c", "a"), // keys + Array("1", "2", "3", "4") // values + ) + ) + + // maps with no overlap + checkEvaluation(MapConcat(Seq(m0, m2)), + Map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) + + // 3 maps + checkEvaluation(MapConcat(Seq(m0, m1, m2)), + ( + Array("a", "b", "c", "a", "d", "e"), // keys + Array("1", "2", "3", "4", "4", "5") // values + ) + ) + + // null reference values + checkEvaluation(MapConcat(Seq(m3, m4)), + ( + Array("a", "b", "a", "c"), // keys + Array("1", "2", null, "3") // values + ) + ) + + // null primitive values + checkEvaluation(MapConcat(Seq(m5, m6)), + ( + Array("a", "b", "a", "c"), // keys + Array(1, 2, null, 3) // values + ) + ) + + // keys that are primitive + checkEvaluation(MapConcat(Seq(m11, m12)), + ( + Array(1, 2, 3, 4), // keys + Array("1", "2", "3", "4") // values + ) + ) + + // keys that are arrays, with overlap + checkEvaluation(MapConcat(Seq(m7, m8)), + ( + Array(List(1, 2), List(3, 4), List(5, 6), List(1, 2)), // keys + Array(1, 2, 3, 4) // values + ) + ) + + // keys that are maps, with overlap + checkEvaluation(MapConcat(Seq(m9, m10)), + ( + Array(Map(1 -> 2, 3 -> 4), Map(5 -> 6, 7 -> 8), Map(9 -> 10, 11 -> 12), + Map(1 -> 2, 3 -> 4)), // keys + Array(1, 2, 3, 4) // values + ) + ) + + // null map + checkEvaluation(MapConcat(Seq(m0, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull, m0)), null) + checkEvaluation(MapConcat(Seq(mNull, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull)), null) + + // single map + checkEvaluation(MapConcat(Seq(m0)), Map("a" -> "1", "b" -> "2")) + + // no map + checkEvaluation(MapConcat(Seq.empty), Map.empty) + + // force split expressions for input in generated code + val expectedKeys = Array.fill(65)(Seq("a", "b")).flatten ++ Array("d", "e") + val expectedValues = Array.fill(65)(Seq("1", "2")).flatten ++ Array("4", "5") + checkEvaluation(MapConcat( + Seq( + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m2 + )), + (expectedKeys, expectedValues)) + + // argument checking + assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess) + assert(MapConcat(Seq(m5, m6)).checkInputDataTypes().isSuccess) + assert(MapConcat(Seq(m0, m5)).checkInputDataTypes().isFailure) + assert(MapConcat(Seq(m0, Literal(12))).checkInputDataTypes().isFailure) + assert(MapConcat(Seq(m0, m1)).dataType.keyType == StringType) + assert(MapConcat(Seq(m0, m1)).dataType.valueType == StringType) + assert(!MapConcat(Seq(m0, m1)).dataType.valueContainsNull) + assert(MapConcat(Seq(m5, m6)).dataType.keyType == StringType) + assert(MapConcat(Seq(m5, m6)).dataType.valueType == IntegerType) + assert(MapConcat(Seq.empty).dataType.keyType == StringType) + assert(MapConcat(Seq.empty).dataType.valueType == StringType) + assert(MapConcat(Seq(m5, m6)).dataType.valueContainsNull) + assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull) + assert(!MapConcat(Seq(m1, m2)).nullable) + assert(MapConcat(Seq(m1, mNull)).nullable) + } + + test("MapFromEntries") { + def arrayType(keyType: DataType, valueType: DataType) : DataType = { + ArrayType( + StructType(Seq( + StructField("a", keyType), + StructField("b", valueType))), + true) + } + def r(values: Any*): InternalRow = create_row(values: _*) + + // Primitive-type keys and values + val aiType = arrayType(IntegerType, IntegerType) + val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType) + val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType) + val ai2 = Literal.create(Seq.empty, aiType) + val ai3 = Literal.create(null, aiType) + val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType) + val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType) + val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType) + + checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20)) + checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null)) + checkEvaluation(MapFromEntries(ai2), Map.empty) + checkEvaluation(MapFromEntries(ai3), null) + checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1)) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(ai5), + "The first field from a struct (key) can't be null.") + checkEvaluation(MapFromEntries(ai6), null) + + // Non-primitive-type keys and values + val asType = arrayType(StringType, StringType) + val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType) + val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType) + val as2 = Literal.create(Seq.empty, asType) + val as3 = Literal.create(null, asType) + val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType) + val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType) + val as6 = Literal.create(Seq(null, r("b", "bb"), null), asType) + + checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb")) + checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null)) + checkEvaluation(MapFromEntries(as2), Map.empty) + checkEvaluation(MapFromEntries(as3), null) + checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a")) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(as5), + "The first field from a struct (key) can't be null.") + checkEvaluation(MapFromEntries(as6), null) + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) - val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) + val d1 = new Decimal().set(10) + val d2 = new Decimal().set(100) + val a4 = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0))) + val a5 = Literal.create(Seq(null, null), ArrayType(NullType)) checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) checkEvaluation(new SortArray(a2), Seq("a", "b")) checkEvaluation(new SortArray(a3), Seq(null, "a", "b")) + checkEvaluation(new SortArray(a4), Seq(d1, d2)) checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3)) checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b")) checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b")) + checkEvaluation(SortArray(a4, Literal(true)), Seq(d1, d2)) checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1)) checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a")) checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) + checkEvaluation(SortArray(a4, Literal(false)), Seq(d2, d1)) checkEvaluation(Literal.create(null, ArrayType(StringType)), null) - checkEvaluation(new SortArray(a4), Seq(null, null)) + checkEvaluation(new SortArray(a5), Seq(null, null)) val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS) checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2))) + + val typeAA = ArrayType(ArrayType(IntegerType)) + val aa1 = Array[java.lang.Integer](1, 2) + val aa2 = Array[java.lang.Integer](3, null, 4) + val arrayArray = Literal.create(Seq(aa2, aa1), typeAA) + + checkEvaluation(new SortArray(arrayArray), Seq(aa1, aa2)) + + val typeAAS = ArrayType(ArrayType(StructType(StructField("a", IntegerType) :: Nil))) + val aas1 = Array(create_row(1)) + val aas2 = Array(create_row(2)) + val arrayArrayStruct = Literal.create(Seq(aas2, aas1), typeAAS) + + checkEvaluation(new SortArray(arrayArrayStruct), Seq(aas1, aas2)) + + checkEvaluation(ArraySort(a0), Seq(1, 2, 3)) + checkEvaluation(ArraySort(a1), Seq[Integer]()) + checkEvaluation(ArraySort(a2), Seq("a", "b")) + checkEvaluation(ArraySort(a3), Seq("a", "b", null)) + checkEvaluation(ArraySort(a4), Seq(d1, d2)) + checkEvaluation(ArraySort(a5), Seq(null, null)) + checkEvaluation(ArraySort(arrayStruct), Seq(create_row(1), create_row(2))) + checkEvaluation(ArraySort(arrayArray), Seq(aa1, aa2)) + checkEvaluation(ArraySort(arrayArrayStruct), Seq(aas1, aas2)) } test("Array contains") { @@ -90,6 +344,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) val a2 = Literal.create(Seq(null), ArrayType(LongType)) val a3 = Literal.create(null, ArrayType(StringType)) + val a4 = Literal.create(Seq(create_row(1)), ArrayType(StructType(Seq( + StructField("a", IntegerType, true))))) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) @@ -104,6 +360,252 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) + + checkEvaluation(ArrayContains(a4, Literal.create(create_row(1), StructType(Seq( + StructField("a", IntegerType, false))))), true) + checkEvaluation(ArrayContains(a4, Literal.create(create_row(0), StructType(Seq( + StructField("a", IntegerType, false))))), false) + + // binary + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), + ArrayType(BinaryType)) + val be = Literal.create(Array[Byte](1, 2), BinaryType) + val nullBinary = Literal.create(null, BinaryType) + + checkEvaluation(ArrayContains(b0, be), true) + checkEvaluation(ArrayContains(b1, be), false) + checkEvaluation(ArrayContains(b0, nullBinary), null) + checkEvaluation(ArrayContains(b2, be), null) + checkEvaluation(ArrayContains(b3, be), true) + + // complex data types + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayContains(aa0, aae), true) + checkEvaluation(ArrayContains(aa1, aae), false) + } + + test("ArraysOverlap") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType)) + val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType)) + val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType)) + val a4 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a5 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType)) + val a6 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType)) + + val emptyIntArray = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) + + checkEvaluation(ArraysOverlap(a0, a1), true) + checkEvaluation(ArraysOverlap(a0, a2), null) + checkEvaluation(ArraysOverlap(a1, a2), true) + checkEvaluation(ArraysOverlap(a1, a3), false) + checkEvaluation(ArraysOverlap(a0, emptyIntArray), false) + checkEvaluation(ArraysOverlap(a2, emptyIntArray), false) + checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) + + checkEvaluation(ArraysOverlap(a4, a5), true) + checkEvaluation(ArraysOverlap(a4, a6), null) + checkEvaluation(ArraysOverlap(a5, a6), false) + + // null handling + checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) + checkEvaluation(ArraysOverlap( + emptyIntArray, Literal.create(Seq(null), ArrayType(IntegerType))), false) + checkEvaluation(ArraysOverlap(Literal.create(null, ArrayType(IntegerType)), a0), null) + checkEvaluation(ArraysOverlap(a0, Literal.create(null, ArrayType(IntegerType))), null) + checkEvaluation(ArraysOverlap( + Literal.create(Seq(null), ArrayType(IntegerType)), + Literal.create(Seq(null), ArrayType(IntegerType))), null) + + // arrays of binaries + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + + checkEvaluation(ArraysOverlap(b0, b1), true) + checkEvaluation(ArraysOverlap(b0, b2), false) + + // arrays of complex data types + val aa0 = Literal.create(Seq[Array[String]](Array[String]("a", "b"), Array[String]("c", "d")), + ArrayType(ArrayType(StringType))) + val aa1 = Literal.create(Seq[Array[String]](Array[String]("e", "f"), Array[String]("a", "b")), + ArrayType(ArrayType(StringType))) + val aa2 = Literal.create(Seq[Array[String]](Array[String]("b", "a"), Array[String]("f", "g")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(ArraysOverlap(aa0, aa1), true) + checkEvaluation(ArraysOverlap(aa0, aa2), false) + + // null handling with complex datatypes + val emptyBinaryArray = Literal.create(Seq.empty[Array[Byte]], ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + checkEvaluation(ArraysOverlap(emptyBinaryArray, b0), false) + checkEvaluation(ArraysOverlap(b0, emptyBinaryArray), false) + checkEvaluation(ArraysOverlap(emptyBinaryArray, arrayWithBinaryNull), false) + checkEvaluation(ArraysOverlap(arrayWithBinaryNull, emptyBinaryArray), false) + checkEvaluation(ArraysOverlap(arrayWithBinaryNull, b0), null) + checkEvaluation(ArraysOverlap(b0, arrayWithBinaryNull), null) + } + + test("Slice") { + val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType)) + val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType)) + val a3 = Literal.create(Seq(1, 2, null, 4), ArrayType(IntegerType)) + + checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2)) + checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5)) + checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6)) + checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6)) + checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)), + "Unexpected value for length") + checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)), + "Unexpected value for start") + checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int]) + checkEvaluation(Slice(a1, Literal(-20), Literal(1)), Seq.empty[String]) + checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null) + checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null) + checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)), + null) + + checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b")) + checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null)) + checkEvaluation(Slice(a0, Literal(10), Literal(1)), Seq.empty[Int]) + checkEvaluation(Slice(a1, Literal(10), Literal(1)), Seq.empty[String]) + checkEvaluation(Slice(a3, Literal(2), Literal(3)), Seq(2, null, 4)) + } + + test("ArrayJoin") { + def testArrays( + arrays: Seq[Expression], + nullReplacement: Option[Expression], + expected: Seq[String]): Unit = { + assert(arrays.length == expected.length) + arrays.zip(expected).foreach { case (arr, exp) => + checkEvaluation(ArrayJoin(arr, Literal(","), nullReplacement), exp) + } + } + + val arrays = Seq(Literal.create(Seq[String]("a", "b"), ArrayType(StringType)), + Literal.create(Seq[String]("a", null, "b"), ArrayType(StringType)), + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal.create(Seq[String]("a", "b", null), ArrayType(StringType)), + Literal.create(Seq[String](null, "a", "b"), ArrayType(StringType)), + Literal.create(Seq[String]("a"), ArrayType(StringType))) + + val withoutNullReplacement = Seq("a,b", "a,b", "", "a,b", "a,b", "a") + val withNullReplacement = Seq("a,b", "a,NULL,b", "NULL", "a,b,NULL", "NULL,a,b", "a") + testArrays(arrays, None, withoutNullReplacement) + testArrays(arrays, Some(Literal("NULL")), withNullReplacement) + + checkEvaluation(ArrayJoin( + Literal.create(null, ArrayType(StringType)), Literal(","), None), null) + checkEvaluation(ArrayJoin( + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal.create(null, StringType), + None), null) + checkEvaluation(ArrayJoin( + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal(","), + Some(Literal.create(null, StringType))), null) + } + + test("ArraysZip") { + val literals = Seq( + Literal.create(Seq(9001, 9002, 9003, null), ArrayType(IntegerType)), + Literal.create(Seq(null, 1L, null, 4L, 11L), ArrayType(LongType)), + Literal.create(Seq(-1, -3, 900, null), ArrayType(IntegerType)), + Literal.create(Seq("a", null, "c"), ArrayType(StringType)), + Literal.create(Seq(null, false, true), ArrayType(BooleanType)), + Literal.create(Seq(1.1, null, 1.3, null), ArrayType(DoubleType)), + Literal.create(Seq(), ArrayType(NullType)), + Literal.create(Seq(null), ArrayType(NullType)), + Literal.create(Seq(192.toByte), ArrayType(ByteType)), + Literal.create( + Seq(Seq(1, 2, 3), null, Seq(4, 5), Seq(1, null, 3)), ArrayType(ArrayType(IntegerType))), + Literal.create(Seq(Array[Byte](1.toByte, 5.toByte)), ArrayType(BinaryType)) + ) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(1))), + List(Row(9001, null), Row(9002, 1L), Row(9003, null), Row(null, 4L), Row(null, 11L))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(2))), + List(Row(9001, -1), Row(9002, -3), Row(9003, 900), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(3))), + List(Row(9001, "a"), Row(9002, null), Row(9003, "c"), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(4))), + List(Row(9001, null), Row(9002, false), Row(9003, true), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(5))), + List(Row(9001, 1.1), Row(9002, null), Row(9003, 1.3), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(6))), + List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(7))), + List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(1), literals(2), literals(3))), + List( + Row(9001, null, -1, "a"), + Row(9002, 1L, -3, null), + Row(9003, null, 900, "c"), + Row(null, 4L, null, null), + Row(null, 11L, null, null))) + + checkEvaluation(ArraysZip(Seq(literals(4), literals(5), literals(6), literals(7), literals(8))), + List( + Row(null, 1.1, null, null, 192.toByte), + Row(false, null, null, null, null), + Row(true, 1.3, null, null, null), + Row(null, null, null, null, null))) + + checkEvaluation(ArraysZip(Seq(literals(9), literals(0))), + List( + Row(List(1, 2, 3), 9001), + Row(null, 9002), + Row(List(4, 5), 9003), + Row(List(1, null, 3), null))) + + checkEvaluation(ArraysZip(Seq(literals(7), literals(10))), + List(Row(null, Array[Byte](1.toByte, 5.toByte)))) + + val longLiteral = + Literal.create((0 to 1000).toSeq, ArrayType(IntegerType)) + + checkEvaluation(ArraysZip(Seq(literals(0), longLiteral)), + List(Row(9001, 0), Row(9002, 1), Row(9003, 2)) ++ + (3 to 1000).map { Row(null, _) }.toList) + + val manyLiterals = (0 to 1000).map { _ => + Literal.create(Seq(1), ArrayType(IntegerType)) + }.toSeq + + val numbers = List( + Row(Seq(9001) ++ (0 to 1000).map { _ => 1 }.toSeq: _*), + Row(Seq(9002) ++ (0 to 1000).map { _ => null }.toSeq: _*), + Row(Seq(9003) ++ (0 to 1000).map { _ => null }.toSeq: _*), + Row(Seq(null) ++ (0 to 1000).map { _ => null }.toSeq: _*)) + checkEvaluation(ArraysZip(Seq(literals(0)) ++ manyLiterals), + List(numbers(0), numbers(1), numbers(2), numbers(3))) + + checkEvaluation(ArraysZip(Seq(literals(0), Literal.create(null, ArrayType(IntegerType)))), null) + checkEvaluation(ArraysZip(Seq()), List()) } test("Array Min") { @@ -126,6 +628,292 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) } + test("Sequence of numbers") { + // test null handling + + checkEvaluation(new Sequence(Literal(null, LongType), Literal(1L)), null) + checkEvaluation(new Sequence(Literal(1L), Literal(null, LongType)), null) + checkEvaluation(new Sequence(Literal(null, LongType), Literal(1L), Literal(1L)), null) + checkEvaluation(new Sequence(Literal(1L), Literal(null, LongType), Literal(1L)), null) + checkEvaluation(new Sequence(Literal(1L), Literal(1L), Literal(null, LongType)), null) + + // test sequence boundaries checking + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)), + EmptyRow, s"Too long sequence: 4294967296. Should be <= $MAX_ROUNDED_ARRAY_LENGTH") + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0") + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(2), Literal(1), Literal(0)), EmptyRow, "boundaries: 2 to 1 by 0") + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(2), Literal(1), Literal(1)), EmptyRow, "boundaries: 2 to 1 by 1") + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1") + + // test sequence with one element (zero step or equal start and stop) + + checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(0)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(1)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(2), Literal(2)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(0), Literal(-2)), Seq(1)) + + // test sequence of different integral types (ascending and descending) + + checkEvaluation(new Sequence(Literal(1L), Literal(3L), Literal(1L)), Seq(1L, 2L, 3L)) + checkEvaluation(new Sequence(Literal(-3), Literal(3), Literal(3)), Seq(-3, 0, 3)) + checkEvaluation( + new Sequence(Literal(3.toShort), Literal(-3.toShort), Literal(-3.toShort)), + Seq(3.toShort, 0.toShort, -3.toShort)) + checkEvaluation( + new Sequence(Literal(-1.toByte), Literal(-3.toByte), Literal(-1.toByte)), + Seq(-1.toByte, -2.toByte, -3.toByte)) + } + + test("Sequence of timestamps") { + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(CalendarInterval.fromString("interval 12 hours"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-02 00:00:01")), + Literal(CalendarInterval.fromString("interval 12 hours"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 12 hours").negate())), + Seq( + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Timestamp.valueOf("2017-12-31 23:59:59")), + Literal(CalendarInterval.fromString("interval 12 hours").negate())), + Seq( + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:00"), + Timestamp.valueOf("2018-03-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month").negate())), + Seq( + Timestamp.valueOf("2018-03-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-03 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month 1 day").negate())), + Seq( + Timestamp.valueOf("2018-03-03 00:00:00"), + Timestamp.valueOf("2018-02-02 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-31 00:00:00")), + Literal(Timestamp.valueOf("2018-04-30 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Timestamp.valueOf("2018-01-31 00:00:00"), + Timestamp.valueOf("2018-02-28 00:00:00"), + Timestamp.valueOf("2018-03-31 00:00:00"), + Timestamp.valueOf("2018-04-30 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month 1 second"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:01"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:04:06")), + Literal(CalendarInterval.fromString("interval 1 month 2 minutes 3 seconds"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:02:03"), + Timestamp.valueOf("2018-03-01 00:04:06"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2023-01-01 00:00:00")), + Literal(CalendarInterval.fromYearMonthString("1-5"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00.000"), + Timestamp.valueOf("2019-06-01 00:00:00.000"), + Timestamp.valueOf("2020-11-01 00:00:00.000"), + Timestamp.valueOf("2022-04-01 00:00:00.000"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2022-04-01 00:00:00")), + Literal(Timestamp.valueOf("2017-01-01 00:00:00")), + Literal(CalendarInterval.fromYearMonthString("1-5").negate())), + Seq( + Timestamp.valueOf("2022-04-01 00:00:00.000"), + Timestamp.valueOf("2020-11-01 00:00:00.000"), + Timestamp.valueOf("2019-06-01 00:00:00.000"), + Timestamp.valueOf("2018-01-01 00:00:00.000"))) + } + + test("Sequence on DST boundaries") { + val timeZone = TimeZone.getTimeZone("Europe/Prague") + val dstOffset = timeZone.getDSTSavings + + def noDST(t: Timestamp): Timestamp = new Timestamp(t.getTime - dstOffset) + + DateTimeTestUtils.withDefaultTimeZone(timeZone) { + // Spring time change + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-25 01:30:00")), + Literal(Timestamp.valueOf("2018-03-25 03:30:00")), + Literal(CalendarInterval.fromString("interval 30 minutes"))), + Seq( + Timestamp.valueOf("2018-03-25 01:30:00"), + Timestamp.valueOf("2018-03-25 03:00:00"), + Timestamp.valueOf("2018-03-25 03:30:00"))) + + // Autumn time change + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-10-28 01:30:00")), + Literal(Timestamp.valueOf("2018-10-28 03:30:00")), + Literal(CalendarInterval.fromString("interval 30 minutes"))), + Seq( + Timestamp.valueOf("2018-10-28 01:30:00"), + noDST(Timestamp.valueOf("2018-10-28 02:00:00")), + noDST(Timestamp.valueOf("2018-10-28 02:30:00")), + Timestamp.valueOf("2018-10-28 02:00:00"), + Timestamp.valueOf("2018-10-28 02:30:00"), + Timestamp.valueOf("2018-10-28 03:00:00"), + Timestamp.valueOf("2018-10-28 03:30:00"))) + } + } + + test("Sequence of dates") { + DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("UTC")) { + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-01-05")), + Literal(CalendarInterval.fromString("interval 2 days"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-01-03"), + Date.valueOf("2018-01-05"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-03-01")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-02-01"), + Date.valueOf("2018-03-01"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-31")), + Literal(Date.valueOf("2018-04-30")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Date.valueOf("2018-01-31"), + Date.valueOf("2018-02-28"), + Date.valueOf("2018-03-31"), + Date.valueOf("2018-04-30"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2023-01-01")), + Literal(CalendarInterval.fromYearMonthString("1-5"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2019-06-01"), + Date.valueOf("2020-11-01"), + Date.valueOf("2022-04-01"))) + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Date.valueOf("1970-01-02")), + Literal(Date.valueOf("1970-01-01")), + Literal(CalendarInterval.fromString("interval 1 day"))), + EmptyRow, "sequence boundaries: 1 to 0 by 1") + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Date.valueOf("1970-01-01")), + Literal(Date.valueOf("1970-02-01")), + Literal(CalendarInterval.fromString("interval 1 month").negate())), + EmptyRow, + s"sequence boundaries: 0 to 2678400000000 by -${28 * CalendarInterval.MICROS_PER_DAY}") + } + } + + test("Sequence with default step") { + // +/- 1 for integral type + checkEvaluation(new Sequence(Literal(1), Literal(3)), Seq(1, 2, 3)) + checkEvaluation(new Sequence(Literal(3), Literal(1)), Seq(3, 2, 1)) + + // +/- 1 day for timestamps + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-03 00:00:00"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-03 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-03 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00"))), + Seq( + Timestamp.valueOf("2018-01-03 00:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + // +/- 1 day for dates + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-01-03"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-01-02"), + Date.valueOf("2018-01-03"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-03")), + Literal(Date.valueOf("2018-01-01"))), + Seq( + Date.valueOf("2018-01-03"), + Date.valueOf("2018-01-02"), + Date.valueOf("2018-01-01"))) + } + test("Reverse") { // Primitive-type elements val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType)) @@ -190,6 +978,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPosition(a3, Literal("")), null) checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayPosition(aa0, aae), 1L) + checkEvaluation(ArrayPosition(aa1, aae), 0L) } test("elementAt") { @@ -227,7 +1023,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) val m2 = Literal.create(null, MapType(StringType, StringType)) - checkEvaluation(ElementAt(m0, Literal(1.0)), null) + assert(ElementAt(m0, Literal(1.0)).checkInputDataTypes().isFailure) checkEvaluation(ElementAt(m0, Literal("d")), null) @@ -238,6 +1034,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(m0, Literal("c")), null) checkEvaluation(ElementAt(m2, Literal("a")), null) + + // test binary type as keys + val mb0 = Literal.create( + Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), + MapType(BinaryType, StringType)) + val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) + + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null) + + checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null) + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null) } test("Concat") { @@ -280,4 +1088,301 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) } + + test("Flatten") { + // Primitive-type test cases + val intArrayType = ArrayType(ArrayType(IntegerType)) + + // Main test cases (primitive type) + val aim1 = Literal.create(Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), intArrayType) + val aim2 = Literal.create(Seq(Seq(1, 2, 3)), intArrayType) + + checkEvaluation(Flatten(aim1), Seq(1, 2, 3, 4, 5, 6)) + checkEvaluation(Flatten(aim2), Seq(1, 2, 3)) + + // Test cases with an empty array (primitive type) + val aie1 = Literal.create(Seq(Seq.empty, Seq(1, 2), Seq(3, 4)), intArrayType) + val aie2 = Literal.create(Seq(Seq(1, 2), Seq.empty, Seq(3, 4)), intArrayType) + val aie3 = Literal.create(Seq(Seq(1, 2), Seq(3, 4), Seq.empty), intArrayType) + val aie4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), intArrayType) + val aie5 = Literal.create(Seq(Seq.empty), intArrayType) + val aie6 = Literal.create(Seq.empty, intArrayType) + + checkEvaluation(Flatten(aie1), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie2), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie3), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie4), Seq.empty) + checkEvaluation(Flatten(aie5), Seq.empty) + checkEvaluation(Flatten(aie6), Seq.empty) + + // Test cases with null elements (primitive type) + val ain1 = Literal.create(Seq(Seq(null, null, null), Seq(4, null)), intArrayType) + val ain2 = Literal.create(Seq(Seq(null, 2, null), Seq(null, null)), intArrayType) + val ain3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), intArrayType) + + checkEvaluation(Flatten(ain1), Seq(null, null, null, 4, null)) + checkEvaluation(Flatten(ain2), Seq(null, 2, null, null, null)) + checkEvaluation(Flatten(ain3), Seq(null, null, null, null)) + + // Test cases with a null array (primitive type) + val aia1 = Literal.create(Seq(null, Seq(1, 2)), intArrayType) + val aia2 = Literal.create(Seq(Seq(1, 2), null), intArrayType) + val aia3 = Literal.create(Seq(null), intArrayType) + val aia4 = Literal.create(null, intArrayType) + + checkEvaluation(Flatten(aia1), null) + checkEvaluation(Flatten(aia2), null) + checkEvaluation(Flatten(aia3), null) + checkEvaluation(Flatten(aia4), null) + + // Non-primitive-type test cases + val strArrayType = ArrayType(ArrayType(StringType)) + val arrArrayType = ArrayType(ArrayType(ArrayType(StringType))) + + // Main test cases (non-primitive type) + val asm1 = Literal.create(Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), strArrayType) + val asm2 = Literal.create(Seq(Seq("a", "b")), strArrayType) + val asm3 = Literal.create(Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d", "e"))), arrArrayType) + + checkEvaluation(Flatten(asm1), Seq("a", "b", "c", "d", "e", "f")) + checkEvaluation(Flatten(asm2), Seq("a", "b")) + checkEvaluation(Flatten(asm3), Seq(Seq("a", "b"), Seq("c"), Seq("d", "e"))) + + // Test cases with an empty array (non-primitive type) + val ase1 = Literal.create(Seq(Seq.empty, Seq("a", "b"), Seq("c", "d")), strArrayType) + val ase2 = Literal.create(Seq(Seq("a", "b"), Seq.empty, Seq("c", "d")), strArrayType) + val ase3 = Literal.create(Seq(Seq("a", "b"), Seq("c", "d"), Seq.empty), strArrayType) + val ase4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), strArrayType) + val ase5 = Literal.create(Seq(Seq.empty), strArrayType) + val ase6 = Literal.create(Seq.empty, strArrayType) + + checkEvaluation(Flatten(ase1), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase2), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase3), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase4), Seq.empty) + checkEvaluation(Flatten(ase5), Seq.empty) + checkEvaluation(Flatten(ase6), Seq.empty) + + // Test cases with null elements (non-primitive type) + val asn1 = Literal.create(Seq(Seq(null, null, "c"), Seq(null, null)), strArrayType) + val asn2 = Literal.create(Seq(Seq(null, null, null), Seq("d", null)), strArrayType) + val asn3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), strArrayType) + + checkEvaluation(Flatten(asn1), Seq(null, null, "c", null, null)) + checkEvaluation(Flatten(asn2), Seq(null, null, null, "d", null)) + checkEvaluation(Flatten(asn3), Seq(null, null, null, null)) + + // Test cases with a null array (non-primitive type) + val asa1 = Literal.create(Seq(null, Seq("a", "b")), strArrayType) + val asa2 = Literal.create(Seq(Seq("a", "b"), null), strArrayType) + val asa3 = Literal.create(Seq(null), strArrayType) + val asa4 = Literal.create(null, strArrayType) + + checkEvaluation(Flatten(asa1), null) + checkEvaluation(Flatten(asa2), null) + checkEvaluation(Flatten(asa3), null) + checkEvaluation(Flatten(asa4), null) + } + + test("ArrayRepeat") { + val intArray = Literal.create(Seq(1, 2), ArrayType(IntegerType)) + val strArray = Literal.create(Seq("hi", "hola"), ArrayType(StringType)) + + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(0)), Seq()) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(-1)), Seq()) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(1)), Seq("hi")) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(2)), Seq("hi", "hi")) + checkEvaluation(ArrayRepeat(Literal(true), Literal(2)), Seq(true, true)) + checkEvaluation(ArrayRepeat(Literal(1), Literal(2)), Seq(1, 1)) + checkEvaluation(ArrayRepeat(Literal(3.2), Literal(2)), Seq(3.2, 3.2)) + checkEvaluation(ArrayRepeat(Literal(null), Literal(2)), Seq[String](null, null)) + checkEvaluation(ArrayRepeat(Literal(null, IntegerType), Literal(2)), Seq[Integer](null, null)) + checkEvaluation(ArrayRepeat(intArray, Literal(2)), Seq(Seq(1, 2), Seq(1, 2))) + checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola"))) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) + } + + test("Array remove") { + val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType)) + val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) + val a2 = Literal.create(Seq[String](null, "", null, ""), ArrayType(StringType)) + val a3 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a4 = Literal.create(null, ArrayType(StringType)) + val a5 = Literal.create(Seq(1, null, 8, 9, null), ArrayType(IntegerType)) + val a6 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType)) + + checkEvaluation(ArrayRemove(a0, Literal(0)), Seq(1, 2, 3, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(1)), Seq(2, 3, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(2)), Seq(1, 3, 5)) + checkEvaluation(ArrayRemove(a0, Literal(3)), Seq(1, 2, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(5)), Seq(1, 2, 3, 2, 2)) + checkEvaluation(ArrayRemove(a0, Literal(null, IntegerType)), null) + + checkEvaluation(ArrayRemove(a1, Literal("")), Seq("b", "a", "a", "c", "b")) + checkEvaluation(ArrayRemove(a1, Literal("a")), Seq("b", "c", "b")) + checkEvaluation(ArrayRemove(a1, Literal("b")), Seq("a", "a", "c")) + checkEvaluation(ArrayRemove(a1, Literal("c")), Seq("b", "a", "a", "b")) + + checkEvaluation(ArrayRemove(a2, Literal("")), Seq(null, null)) + checkEvaluation(ArrayRemove(a2, Literal(null, StringType)), null) + + checkEvaluation(ArrayRemove(a3, Literal(1)), Seq.empty[Integer]) + + checkEvaluation(ArrayRemove(a4, Literal("a")), null) + + checkEvaluation(ArrayRemove(a5, Literal(9)), Seq(1, null, 8, null)) + checkEvaluation(ArrayRemove(a6, Literal(false)), Seq(true, true)) + + // complex data types + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), + Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), + ArrayType(BinaryType)) + val nullBinary = Literal.create(null, BinaryType) + + val dataToRemove1 = Literal.create(Array[Byte](5, 6), BinaryType) + checkEvaluation(ArrayRemove(b0, dataToRemove1), + Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](1, 2))) + checkEvaluation(ArrayRemove(b0, nullBinary), null) + checkEvaluation(ArrayRemove(b1, dataToRemove1), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayRemove(b2, dataToRemove1), Seq[Array[Byte]](null, Array[Byte](1, 2))) + + val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1)), ArrayType(ArrayType(IntegerType))) + val dataToRemove2 = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayRemove(c0, dataToRemove2), Seq[Seq[Int]](Seq[Int](3, 4))) + checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) + checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1))) + } + + test("Array Distinct") { + val a0 = Literal.create(Seq(2, 1, 2, 3, 4, 4, 5), ArrayType(IntegerType)) + val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a2 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) + val a3 = Literal.create(Seq("b", null, "a", null, "a", null), ArrayType(StringType)) + val a4 = Literal.create(Seq(null, null, null), ArrayType(NullType)) + val a5 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType)) + val a6 = Literal.create(Seq(1.123, 0.1234, 1.121, 1.123, 1.1230, 1.121, 0.1234), + ArrayType(DoubleType)) + val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f), + ArrayType(FloatType)) + + checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5)) + checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer]) + checkEvaluation(new ArrayDistinct(a2), Seq("b", "a", "c")) + checkEvaluation(new ArrayDistinct(a3), Seq("b", null, "a")) + checkEvaluation(new ArrayDistinct(a4), Seq(null)) + checkEvaluation(new ArrayDistinct(a5), Seq(true, false)) + checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121)) + checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f)) + + // complex data types + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), + Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), null, Array[Byte](1, 2), + null, Array[Byte](5, 6), null), ArrayType(BinaryType)) + + checkEvaluation(ArrayDistinct(b0), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2))) + checkEvaluation(ArrayDistinct(b1), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayDistinct(b2), Seq[Array[Byte]](Array[Byte](5, 6), null, + Array[Byte](1, 2))) + + val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2), + Seq[Int](3, 4), Seq[Int](1, 2)), ArrayType(ArrayType(IntegerType))) + val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4))) + checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) + checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1))) + } + + test("Array Union") { + val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, containsNull = false)) + val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true)) + val a03 = Literal.create(Seq(-5, 4, -3, 2, 4), ArrayType(IntegerType, containsNull = false)) + val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType, containsNull = false)) + val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType, containsNull = false)) + val a07 = Literal.create(Seq[Short](1, 2, 3), ArrayType(ShortType, containsNull = false)) + val a08 = Literal.create(Seq[Short](4, 2), ArrayType(ShortType, containsNull = false)) + + val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, containsNull = false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, containsNull = false)) + val a12 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType, containsNull = true)) + val a13 = Literal.create(Seq(-5L, 4L, -3L, 2L, -1L), ArrayType(LongType, containsNull = false)) + val a14 = Literal.create(Seq.empty[Long], ArrayType(LongType, containsNull = false)) + + val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, containsNull = false)) + val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, containsNull = false)) + val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, containsNull = true)) + + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val a31 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayUnion(a00, a01), Seq(1, 2, 3, 4)) + checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3)) + checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, 1, null, 5)) + checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5)) + checkEvaluation(ArrayUnion(a05, a06), Seq[Byte](1, 2, 3, 4)) + checkEvaluation(ArrayUnion(a07, a08), Seq[Short](1, 2, 3, 4)) + + checkEvaluation(ArrayUnion(a10, a11), Seq(1L, 2L, 3L, 4L)) + checkEvaluation(ArrayUnion(a12, a13), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) + checkEvaluation(ArrayUnion(a13, a12), Seq(-5L, 4L, -3L, 2L, -1L, 1L, null, 5L)) + checkEvaluation(ArrayUnion(a12, a14), Seq(1L, 2L, null, 4L, 5L)) + + checkEvaluation(ArrayUnion(a20, a21), Seq("b", "a", "c", "d", "f")) + checkEvaluation(ArrayUnion(a20, a22), Seq("b", "a", "c", null, "g")) + + checkEvaluation(ArrayUnion(a30, a30), Seq(null)) + checkEvaluation(ArrayUnion(a20, a31), null) + checkEvaluation(ArrayUnion(a31, a20), null) + + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]]( + Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](1, 2)), ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), null), ArrayType(BinaryType)) + val b5 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), ArrayType(BinaryType)) + val b6 = Literal.create(Seq.empty, ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + + checkEvaluation(ArrayUnion(b0, b1), + Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](2, 1), Array[Byte](4, 3))) + checkEvaluation(ArrayUnion(b0, b2), + Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](4, 3))) + checkEvaluation(ArrayUnion(b2, b4), Seq(Array[Byte](1, 2), Array[Byte](4, 3), null)) + checkEvaluation(ArrayUnion(b3, b0), + Seq(Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](5, 6))) + checkEvaluation(ArrayUnion(b4, b0), Seq(Array[Byte](1, 2), null, Array[Byte](5, 6))) + checkEvaluation(ArrayUnion(b4, b5), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b6, b4), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b4, arrayWithBinaryNull), Seq(Array[Byte](1, 2), null)) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayUnion(aa0, aa1), + Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](5, 6), Seq[Int](2, 1))) + + assert(ArrayUnion(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayUnion(a00, a02).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayUnion(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayUnion(a20, a22).dataType.asInstanceOf[ArrayType].containsNull === true) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index b4138ce366b3a..726193b411737 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -186,6 +186,50 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("MapFromArrays") { + def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { + // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order. + scala.collection.immutable.ListMap(keys.zip(values): _*) + } + + val intSeq = Seq(5, 10, 15, 20, 25) + val longSeq = intSeq.map(_.toLong) + val strSeq = intSeq.map(_.toString) + val integerSeq = Seq[java.lang.Integer](5, 10, 15, 20, 25) + val intWithNullSeq = Seq[java.lang.Integer](5, 10, null, 20, 25) + val longWithNullSeq = intSeq.map(java.lang.Long.valueOf(_)) + + val intArray = Literal.create(intSeq, ArrayType(IntegerType, false)) + val longArray = Literal.create(longSeq, ArrayType(LongType, false)) + val strArray = Literal.create(strSeq, ArrayType(StringType, false)) + + val integerArray = Literal.create(integerSeq, ArrayType(IntegerType, true)) + val intWithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true)) + val longWithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true)) + + val nullArray = Literal.create(null, ArrayType(StringType, false)) + + checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq)) + checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq)) + checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq)) + + checkEvaluation( + MapFromArrays(strArray, intWithNullArray), createMap(strSeq, intWithNullSeq)) + checkEvaluation( + MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq)) + checkEvaluation( + MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq)) + checkEvaluation(MapFromArrays(nullArray, nullArray), null) + + intercept[RuntimeException] { + checkEvaluation(MapFromArrays(intWithNullArray, strArray), null) + } + intercept[RuntimeException] { + checkEvaluation( + MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null) + } + } + test("CreateStruct") { val row = create_row(1, 2, 3) val c1 = 'a.int.at(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index a099119732e25..e068c32500cfc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -113,6 +113,76 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5))).nullable === true) } + test("if/case when - null flags of non-primitive types") { + val arrayWithNulls = Literal.create(Seq("a", null, "b"), ArrayType(StringType, true)) + val arrayWithoutNulls = Literal.create(Seq("c", "d"), ArrayType(StringType, false)) + val structWithNulls = Literal.create( + create_row(null, null), + StructType(Seq(StructField("a", IntegerType, true), StructField("b", StringType, true)))) + val structWithoutNulls = Literal.create( + create_row(1, "a"), + StructType(Seq(StructField("a", IntegerType, false), StructField("b", StringType, false)))) + val mapWithNulls = Literal.create(Map(1 -> null), MapType(IntegerType, StringType, true)) + val mapWithoutNulls = Literal.create(Map(1 -> "a"), MapType(IntegerType, StringType, false)) + + val arrayIf1 = If(Literal.FalseLiteral, arrayWithNulls, arrayWithoutNulls) + val arrayIf2 = If(Literal.FalseLiteral, arrayWithoutNulls, arrayWithNulls) + val arrayIf3 = If(Literal.TrueLiteral, arrayWithNulls, arrayWithoutNulls) + val arrayIf4 = If(Literal.TrueLiteral, arrayWithoutNulls, arrayWithNulls) + val structIf1 = If(Literal.FalseLiteral, structWithNulls, structWithoutNulls) + val structIf2 = If(Literal.FalseLiteral, structWithoutNulls, structWithNulls) + val structIf3 = If(Literal.TrueLiteral, structWithNulls, structWithoutNulls) + val structIf4 = If(Literal.TrueLiteral, structWithoutNulls, structWithNulls) + val mapIf1 = If(Literal.FalseLiteral, mapWithNulls, mapWithoutNulls) + val mapIf2 = If(Literal.FalseLiteral, mapWithoutNulls, mapWithNulls) + val mapIf3 = If(Literal.TrueLiteral, mapWithNulls, mapWithoutNulls) + val mapIf4 = If(Literal.TrueLiteral, mapWithoutNulls, mapWithNulls) + + val arrayCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithNulls)), arrayWithoutNulls) + val arrayCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithoutNulls)), arrayWithNulls) + val arrayCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithNulls)), arrayWithoutNulls) + val arrayCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithoutNulls)), arrayWithNulls) + val structCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, structWithNulls)), structWithoutNulls) + val structCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, structWithoutNulls)), structWithNulls) + val structCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, structWithNulls)), structWithoutNulls) + val structCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, structWithoutNulls)), structWithNulls) + val mapCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, mapWithNulls)), mapWithoutNulls) + val mapCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, mapWithoutNulls)), mapWithNulls) + val mapCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, mapWithNulls)), mapWithoutNulls) + val mapCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, mapWithoutNulls)), mapWithNulls) + + def checkResult(expectedType: DataType, expectedValue: Any, result: Expression): Unit = { + assert(expectedType == result.dataType) + checkEvaluation(result, expectedValue) + } + + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf1) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf2) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf3) + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf4) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf1) + checkResult(structWithNulls.dataType, structWithNulls.value, structIf2) + checkResult(structWithNulls.dataType, structWithNulls.value, structIf3) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf4) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf1) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf2) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf3) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf4) + + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen1) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen2) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen3) + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen4) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen1) + checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen2) + checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen3) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen4) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen1) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen2) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen3) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen4) + } + test("case key when") { val row = create_row(null, 1, 2, "a", "b", "c") val c1 = 'a.int.at(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 080ec487cfa6a..63b24fb9eb13a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -464,34 +464,47 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { MonthsBetween( Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)), Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)), - timeZoneId), - 3.94959677) + Literal.TrueLiteral, + timeZoneId = timeZoneId), 3.94959677) checkEvaluation( MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)), - timeZoneId), - 0.0) - checkEvaluation( - MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), - timeZoneId), - -2.0) - checkEvaluation( - MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)), - timeZoneId), - 1.0) + Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)), + Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)), + Literal.FalseLiteral, + timeZoneId = timeZoneId), 3.9495967741935485) + + Seq(Literal.FalseLiteral, Literal.TrueLiteral). foreach { roundOff => + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)), + roundOff, + timeZoneId = timeZoneId), 0.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + roundOff, + timeZoneId = timeZoneId), -2.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)), + roundOff, + timeZoneId = timeZoneId), 1.0) + } val t = Literal(Timestamp.valueOf("2015-03-31 22:00:00")) val tnull = Literal.create(null, TimestampType) - checkEvaluation(MonthsBetween(t, tnull, timeZoneId), null) - checkEvaluation(MonthsBetween(tnull, t, timeZoneId), null) - checkEvaluation(MonthsBetween(tnull, tnull, timeZoneId), null) + checkEvaluation(MonthsBetween(t, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation(MonthsBetween(tnull, t, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation( + MonthsBetween(tnull, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation( + MonthsBetween(t, t, Literal.create(null, BooleanType), timeZoneId = timeZoneId), null) checkConsistencyBetweenInterpretedAndCodegen( - (time1: Expression, time2: Expression) => MonthsBetween(time1, time2, timeZoneId), - TimestampType, TimestampType) + (time1: Expression, time2: Expression, roundOff: Expression) => + MonthsBetween(time1, time2, roundOff, timeZoneId = timeZoneId), + TimestampType, TimestampType, BooleanType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4bf6d7107d7e..14bfa212b5496 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf @@ -40,7 +41,7 @@ import org.apache.spark.util.Utils /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. */ -trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { +trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBase { self: SparkFunSuite => protected def create_row(values: Any*): InternalRow = { @@ -98,12 +99,21 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result + case (result: UnsafeRow, expected: GenericInternalRow) => + val structType = exprDataType.asInstanceOf[StructType] + result.toSeq(structType) == expected.toSeq(structType) case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) case _ => result == expected } } + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( + expression: => Expression, + expectedErrMsg: String): Unit = { + checkExceptionInExpression[T](expression, InternalRow.empty, expectedErrMsg) + } + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( expression: => Expression, inputRow: InternalRow, @@ -196,39 +206,34 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection) - checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection) - } - - protected def checkEvaluationWithUnsafeProjection( - expression: Expression, - expected: Any, - inputRow: InternalRow, - factory: UnsafeProjectionCreator): Unit = { - val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory) - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - - if (expected == null) { - if (!unsafeRow.isNullAt(0)) { - val expectedRow = InternalRow(expected, expected) - fail("Incorrect evaluation in unsafe mode: " + - s"$expression, actual: $unsafeRow, expected: $expectedRow$input") - } - } else { - val lit = InternalRow(expected, expected) - val expectedRow = - factory.create(Array(expression.dataType, expression.dataType)).apply(lit) - if (unsafeRow != expectedRow) { - fail("Incorrect evaluation in unsafe mode: " + - s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow) + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + + if (expected == null) { + if (!unsafeRow.isNullAt(0)) { + val expectedRow = InternalRow(expected, expected) + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } else { + val lit = InternalRow(expected, expected) + val expectedRow = + UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) + if (unsafeRow != expectedRow) { + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } } } } protected def evaluateWithUnsafeProjection( expression: Expression, - inputRow: InternalRow = EmptyRow, - factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = { + inputRow: InternalRow = EmptyRow): InternalRow = { // SPARK-16489 Explicitly doing code generation twice so code gen will fail if // some expression is reusing variable names across different instances. // This behavior is tested in ExpressionEvalHelperSuite. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala index 64b65e2070ed6..7c7c4cccee253 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -45,7 +46,7 @@ case class BadCodegenExpression() extends LeafExpression { override def eval(input: InternalRow): Any = 10 override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.copy(code = - s""" + code""" |int some_variable = 11; |int ${ev.value} = 10; """.stripMargin) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 7812319756eae..04f1c8ce0b83d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -688,24 +688,28 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) { val input = """{ - | "a": 1, - | "c": "foo" - |} - |""".stripMargin + | "a": 1, + | "c": "foo" + |} + |""".stripMargin val jsonSchema = new StructType() .add("a", LongType, nullable = false) .add("b", StringType, nullable = false) .add("c", StringType, nullable = false) val output = InternalRow(1L, null, UTF8String.fromString("foo")) - checkEvaluation( - JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId), - output - ) - val schema = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) - .dataType + val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) + checkEvaluation(expr, output) + val schema = expr.dataType val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema assert(schemaToCompare == schema) } } } + + test("SPARK-24709: infer schema of json strings") { + checkEvaluation(SchemaOfJson(Literal.create("""{"col":0}""")), "struct") + checkEvaluation( + SchemaOfJson(Literal.create("""{"col0":["a"], "col1": {"col2": "b"}}""")), + "struct,col1:struct>") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index a9e0eb0e377a6..86f80fe66d28b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -219,4 +219,11 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkUnsupportedTypeInLiteral(Map("key1" -> 1, "key2" -> 2)) checkUnsupportedTypeInLiteral(("mike", 29, 1.0)) } + + test("SPARK-24571: char literals") { + checkEvaluation(Literal('X'), "X") + checkEvaluation(Literal.create('0'), "0") + checkEvaluation(Literal('\u0000'), "\u0000") + checkEvaluation(Literal.create('\n'), "\n") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala new file mode 100644 index 0000000000000..4d69dc32ace82 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.types.{IntegerType, StringType} + +class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("mask") { + checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "U", "l", "#"), "llll-UUUU-####-####") + checkEvaluation( + new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l"), Literal("#")), + "llll-UUUU-####-####") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l")), + "llll-UUUU-nnnn-nnnn") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U")), "xxxx-UUUU-nnnn-nnnn") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(new Mask(Literal(null, StringType)), null) + checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), null, "l", "#"), "llll-XXXX-####-####") + checkEvaluation(new Mask( + Literal("abcd-EFGH-8765-4321"), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("Upper")), + "xxxx-UUUU-nnnn-nnnn") + checkEvaluation(new Mask(Literal("")), "") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("")), "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "", "", ""), "xxxx-XXXX-nnnn-nnnn") + // scalastyle:off nonascii + checkEvaluation(Mask(Literal("Ul9U"), "\u2200", null, null), "\u2200xn\u2200") + checkEvaluation(new Mask(Literal("Hello World, こんにちは, 𠀋"), Literal("あ"), Literal("𡈽")), + "あ𡈽𡈽𡈽𡈽 あ𡈽𡈽𡈽𡈽, こんにちは, 𠀋") + // scalastyle:on nonascii + intercept[AnalysisException] { + checkEvaluation(new Mask(Literal(""), Literal(1)), "") + } + } + + test("mask_first_n") { + checkEvaluation(MaskFirstN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), + "lU#l-UFGH-8765") + checkEvaluation(new MaskFirstN( + Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "llll-UFGH-8765-4321") + checkEvaluation( + new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), + "llll-UFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), + "xxxx-UFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), + "xxxx-XFGH-8765-4321") + intercept[AnalysisException] { + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") + } + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321")), "xxxx-EFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal(null, StringType)), null) + checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "llll-EFGH-8765-4321") + checkEvaluation(new MaskFirstN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "xxxx-EFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), + "xxxx-UFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("")), "") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), + "xxxx-EFGH-8765-4321") + checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "abcd-EFGH-8765-4321") + // scalastyle:off nonascii + checkEvaluation(MaskFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") + checkEvaluation(new MaskFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Xxxxo World") + // scalastyle:on nonascii + } + + test("mask_last_n") { + checkEvaluation(MaskLastN(Literal("abcd-EFGH-aB3d"), 6, "U", "l", "#"), + "abcd-EFGU-lU#l") + checkEvaluation(new MaskLastN( + Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "abcd-EFGU-####") + checkEvaluation( + new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l")), + "abcd-EFGU-nnnn") + checkEvaluation( + new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U")), + "abcd-EFGU-nnnn") + checkEvaluation( + new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6)), + "abcd-EFGX-nnnn") + intercept[AnalysisException] { + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765"), Literal("U")), "") + } + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321")), "abcd-EFGH-8765-nnnn") + checkEvaluation(new MaskLastN(Literal(null, StringType)), null) + checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "abcd-EFGH-8765-nnnn") + checkEvaluation(new MaskLastN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "abcd-EFGH-8765-nnnn") + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(12), Literal("Upper")), + "abcd-EFUU-nnnn-nnnn") + checkEvaluation(new MaskLastN(Literal("")), "") + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(16), Literal("")), + "abcx-XXXX-nnnn-nnnn") + checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "abcd-EFGH-8765-4321") + // scalastyle:off nonascii + checkEvaluation(MaskLastN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") + checkEvaluation(new MaskLastN(Literal("あ, 𠀋, Hello World あ 𠀋"), Literal(10)), + "あ, 𠀋, Hello Xxxxx あ 𠀋") + // scalastyle:on nonascii + } + + test("mask_show_first_n") { + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-aB3d"), 6, "U", "l", "#"), + "abcd-EUUU-####-lU#l") + checkEvaluation(new MaskShowFirstN( + Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "abcd-EUUU-####-####") + checkEvaluation( + new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), + "abcd-EUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), + "abcd-EUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), + "abcd-EXXX-nnnn-nnnn") + intercept[AnalysisException] { + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") + } + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321")), "abcd-XXXX-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal(null, StringType)), null) + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "abcd-UUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "abcd-XXXX-nnnn-nnnn") + checkEvaluation( + new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), + "abcd-EUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal("")), "") + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), + "abcd-XXXX-nnnn-nnnn") + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "abcd-EFGH-8765-4321") + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + // scalastyle:off nonascii + checkEvaluation(MaskShowFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") + checkEvaluation(new MaskShowFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Hellx Xxxxx") + // scalastyle:on nonascii + } + + test("mask_show_last_n") { + checkEvaluation(MaskShowLastN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), + "lU#l-UUUH-8765") + checkEvaluation(new MaskShowLastN( + Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "llll-UUUU-###5-4321") + checkEvaluation( + new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), + "llll-UUUU-nnn5-4321") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), + "xxxx-UUUU-nnn5-4321") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6)), + "xxxx-XXXX-nnn5-4321") + intercept[AnalysisException] { + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") + } + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-4321") + checkEvaluation(new MaskShowLastN(Literal(null, StringType)), null) + checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "llll-UUUU-nnnn-4321") + checkEvaluation(new MaskShowLastN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "xxxx-XXXX-nnnn-4321") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), + "xxxx-UUUU-nnn5-4321") + checkEvaluation(new MaskShowLastN(Literal("")), "") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), + "xxxx-XXXX-nnnn-4321") + checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "abcd-EFGH-8765-4321") + checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + // scalastyle:off nonascii + checkEvaluation(MaskShowLastN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") + checkEvaluation(new MaskShowLastN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Xello World") + // scalastyle:on nonascii + } + + test("mask_hash") { + checkEvaluation(MaskHash(Literal("abcd-EFGH-8765-4321")), "60c713f5ec6912229d2060df1c322776") + checkEvaluation(MaskHash(Literal("")), "d41d8cd98f00b204e9800998ecf8427e") + checkEvaluation(MaskHash(Literal(null, StringType)), null) + // scalastyle:off nonascii + checkEvaluation(MaskHash(Literal("\u2200x9U")), "f1243ef123d516b1f32a3a75309e5711") + // scalastyle:on nonascii + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index bcd035c1eba0b..20d568c44258f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -21,12 +21,13 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.{RandomDataGenerator, Row} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.encoders._ @@ -37,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class InvokeTargetClass extends Serializable { def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 @@ -80,10 +81,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val structExpected = new GenericArrayData( Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4)))) checkEvaluationWithUnsafeProjection( - structEncoder.serializer.head, - structExpected, - structInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + structEncoder.serializer.head, structExpected, structInputRow) // test UnsafeArray-backed data val arrayEncoder = ExpressionEncoder[Array[Array[Int]]] @@ -91,10 +89,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val arrayExpected = new GenericArrayData( Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4)))) checkEvaluationWithUnsafeProjection( - arrayEncoder.serializer.head, - arrayExpected, - arrayInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + arrayEncoder.serializer.head, arrayExpected, arrayInputRow) // test UnsafeMap-backed data val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]] @@ -108,10 +103,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { new GenericArrayData(Array(3, 4)), new GenericArrayData(Array(300, 400))))) checkEvaluationWithUnsafeProjection( - mapEncoder.serializer.head, - mapExpected, - mapInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + mapEncoder.serializer.head, mapExpected, mapInputRow) } test("SPARK-23582: StaticInvoke should support interpreted execution") { @@ -222,7 +214,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal.fromObject(new java.util.LinkedList[Int]), Map("nonexisting" -> Literal(1))) checkExceptionInExpression[Exception](initializeWithNonexistingMethod, - InternalRow.fromSeq(Seq()), """A method named "nonexisting" is not declared in any enclosing class """ + "nor any supertype") @@ -286,8 +277,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluationWithUnsafeProjection( expr, expected, - inputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + inputRow) } checkEvaluationWithOptimization(expr, expected, inputRow) } @@ -296,7 +286,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) => - checkEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) + checkObjectExprEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) } // If an input row or a field are null, a runtime exception will be thrown @@ -472,6 +462,140 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val deserializer = toMapExpr.copy(inputData = Literal.create(data)) checkObjectExprEvaluation(deserializer, expected = data) } + + test("SPARK-23595 ValidateExternalType should support interpreted execution") { + val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) + Seq( + (true, BooleanType), + (2.toByte, ByteType), + (5.toShort, ShortType), + (23, IntegerType), + (61L, LongType), + (1.0f, FloatType), + (10.0, DoubleType), + ("abcd".getBytes, BinaryType), + ("abcd", StringType), + (BigDecimal.valueOf(10), DecimalType.IntDecimal), + (CalendarInterval.fromString("interval 3 day"), CalendarIntervalType), + (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal), + (Array(3, 2, 1), ArrayType(IntegerType)) + ).foreach { case (input, dt) => + val validateType = ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), dt) + checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input)))) + } + + checkExceptionInExpression[RuntimeException]( + ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), DoubleType), + InternalRow.fromSeq(Seq(Row(1))), + "java.lang.Integer is not a valid external type for schema of double") + } + + private def javaMapSerializerFor( + keyClazz: Class[_], + valueClazz: Class[_])(inputObject: Expression): Expression = { + + def kvSerializerFor(inputObject: Expression, clazz: Class[_]): Expression = clazz match { + case c if c == classOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case c if c == classOf[java.lang.String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + } + + ExternalMapToCatalyst( + inputObject, + ObjectType(keyClazz), + kvSerializerFor(_, keyClazz), + keyNullable = true, + ObjectType(valueClazz), + kvSerializerFor(_, valueClazz), + valueNullable = true + ) + } + + private def scalaMapSerializerFor[T: TypeTag, U: TypeTag](inputObject: Expression): Expression = { + import org.apache.spark.sql.catalyst.ScalaReflection._ + + val curId = new java.util.concurrent.atomic.AtomicInteger() + + def kvSerializerFor[V: TypeTag](inputObject: Expression): Expression = + localTypeOf[V].dealias match { + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + case _ => + inputObject + } + + ExternalMapToCatalyst( + inputObject, + dataTypeFor[T], + kvSerializerFor[T], + keyNullable = !localTypeOf[T].typeSymbol.asClass.isPrimitive, + dataTypeFor[U], + kvSerializerFor[U], + valueNullable = !localTypeOf[U].typeSymbol.asClass.isPrimitive + ) + } + + test("SPARK-23589 ExternalMapToCatalyst should support interpreted execution") { + // Simple test + val scalaMap = scala.collection.Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3") + val javaMap = new java.util.HashMap[java.lang.Integer, java.lang.String]() { + { + put(0, "v0") + put(1, "v1") + put(2, null) + put(3, "v3") + } + } + val expected = CatalystTypeConverters.convertToCatalyst(scalaMap) + + // Java Map + val serializer1 = javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])( + Literal.fromObject(javaMap)) + checkEvaluation(serializer1, expected) + + // Scala Map + val serializer2 = scalaMapSerializerFor[Int, String](Literal.fromObject(scalaMap)) + checkEvaluation(serializer2, expected) + + // NULL key test + val scalaMapHasNullKey = scala.collection.Map[java.lang.Integer, String]( + null.asInstanceOf[java.lang.Integer] -> "v0", new java.lang.Integer(1) -> "v1") + val javaMapHasNullKey = new java.util.HashMap[java.lang.Integer, java.lang.String]() { + { + put(null, "v0") + put(1, "v1") + } + } + + // Java Map + val serializer3 = + javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])( + Literal.fromObject(javaMapHasNullKey)) + checkExceptionInExpression[RuntimeException]( + serializer3, EmptyRow, "Cannot use null as map key!") + + // Scala Map + val serializer4 = scalaMapSerializerFor[java.lang.Integer, String]( + Literal.fromObject(scalaMapHasNullKey)) + + checkExceptionInExpression[RuntimeException]( + serializer4, EmptyRow, "Cannot use null as map key!") + } } class TestBean extends Serializable { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 1bfd180ae4393..ac76b17ef4761 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -449,4 +449,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(EqualNullSafe(Literal(null, DoubleType), Literal(-1.0d)), false) checkEvaluation(EqualNullSafe(Literal(-1.0d), Literal(null, DoubleType)), false) } + + test("Interpreted Predicate should initialize nondeterministic expressions") { + val interpreted = InterpretedPredicate.create(LessThan(Rand(7), Literal(1.0))) + interpreted.initialize(0) + assert(interpreted.eval(new UnsafeRow())) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala new file mode 100644 index 0000000000000..cc2e2a993d629 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Date, Timestamp} +import java.util.TimeZone + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ + +class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("SortPrefix") { + val b1 = Literal.create(false, BooleanType) + val b2 = Literal.create(true, BooleanType) + val i1 = Literal.create(20132983, IntegerType) + val i2 = Literal.create(-20132983, IntegerType) + val l1 = Literal.create(20132983, LongType) + val l2 = Literal.create(-20132983, LongType) + val millis = 1524954911000L; + // Explicitly choose a time zone, since Date objects can create different values depending on + // local time zone of the machine on which the test is running + val oldDefaultTZ = TimeZone.getDefault + val d1 = try { + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + Literal.create(new java.sql.Date(millis), DateType) + } finally { + TimeZone.setDefault(oldDefaultTZ) + } + val t1 = Literal.create(new Timestamp(millis), TimestampType) + val f1 = Literal.create(0.7788229f, FloatType) + val f2 = Literal.create(-0.7788229f, FloatType) + val db1 = Literal.create(0.7788229d, DoubleType) + val db2 = Literal.create(-0.7788229d, DoubleType) + val s1 = Literal.create("T", StringType) + val s2 = Literal.create("This is longer than 8 characters", StringType) + val bin1 = Literal.create(Array[Byte](12), BinaryType) + val bin2 = Literal.create(Array[Byte](12, 17, 99, 0, 0, 0, 2, 3, 0xf4.asInstanceOf[Byte]), + BinaryType) + val dec1 = Literal(Decimal(20132983L, 10, 2)) + val dec2 = Literal(Decimal(20132983L, 19, 2)) + val dec3 = Literal(Decimal(20132983L, 21, 2)) + val list1 = Literal(List(1, 2), ArrayType(IntegerType)) + val nullVal = Literal.create(null, IntegerType) + + checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L) + checkEvaluation(SortPrefix(SortOrder(b2, Ascending)), 1L) + checkEvaluation(SortPrefix(SortOrder(i1, Ascending)), 20132983L) + checkEvaluation(SortPrefix(SortOrder(i2, Ascending)), -20132983L) + checkEvaluation(SortPrefix(SortOrder(l1, Ascending)), 20132983L) + checkEvaluation(SortPrefix(SortOrder(l2, Ascending)), -20132983L) + // For some reason, the Literal.create code gives us the number of days since the epoch + checkEvaluation(SortPrefix(SortOrder(d1, Ascending)), 17649L) + checkEvaluation(SortPrefix(SortOrder(t1, Ascending)), millis * 1000) + checkEvaluation(SortPrefix(SortOrder(f1, Ascending)), + DoublePrefixComparator.computePrefix(f1.value.asInstanceOf[Float].toDouble)) + checkEvaluation(SortPrefix(SortOrder(f2, Ascending)), + DoublePrefixComparator.computePrefix(f2.value.asInstanceOf[Float].toDouble)) + checkEvaluation(SortPrefix(SortOrder(db1, Ascending)), + DoublePrefixComparator.computePrefix(db1.value.asInstanceOf[Double])) + checkEvaluation(SortPrefix(SortOrder(db2, Ascending)), + DoublePrefixComparator.computePrefix(db2.value.asInstanceOf[Double])) + checkEvaluation(SortPrefix(SortOrder(s1, Ascending)), + StringPrefixComparator.computePrefix(s1.value.asInstanceOf[UTF8String])) + checkEvaluation(SortPrefix(SortOrder(s2, Ascending)), + StringPrefixComparator.computePrefix(s2.value.asInstanceOf[UTF8String])) + checkEvaluation(SortPrefix(SortOrder(bin1, Ascending)), + BinaryPrefixComparator.computePrefix(bin1.value.asInstanceOf[Array[Byte]])) + checkEvaluation(SortPrefix(SortOrder(bin2, Ascending)), + BinaryPrefixComparator.computePrefix(bin2.value.asInstanceOf[Array[Byte]])) + checkEvaluation(SortPrefix(SortOrder(dec1, Ascending)), 20132983L) + checkEvaluation(SortPrefix(SortOrder(dec2, Ascending)), 2013298L) + checkEvaluation(SortPrefix(SortOrder(dec3, Ascending)), + DoublePrefixComparator.computePrefix(201329.83d)) + checkEvaluation(SortPrefix(SortOrder(list1, Ascending)), 0L) + checkEvaluation(SortPrefix(SortOrder(nullVal, Ascending)), null) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index f1a6f9b8889fa..aa334e040d5fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -706,6 +706,30 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "15,159,339,180,002,773.2778") checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) assert(FormatNumber(Literal.create(null, NullType), Literal(3)).resolved === false) + + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##############.###")), "12332.123") + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##.###")), "12332.123") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal("##.####")), "4") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal("##.####")), "4") + checkEvaluation(FormatNumber(Literal(4.0f), Literal("##.###")), "4") + checkEvaluation(FormatNumber(Literal(4), Literal("##.###")), "4") + checkEvaluation(FormatNumber(Literal(12831273.23481d), + Literal("###,###,###,###,###.###")), "12,831,273.235") + checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal("")), "12,831,274") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal("###,###,###,###,###.###")), + "123,123,324,123") + checkEvaluation( + FormatNumber(Literal(Decimal(123123324123L) * Decimal(123123.21234d)), + Literal("###,###,###,###,###.####")), "15,159,339,180,002,773.2778") + checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal("##.###")), null) + assert(FormatNumber(Literal.create(null, NullType), Literal("##.###")).resolved === false) + + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("#,###,###,###,###,###,##0")), + "12,332") + checkEvaluation(FormatNumber( + Literal.create(null, IntegerType), Literal.create(null, StringType)), null) + checkEvaluation(FormatNumber( + Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) } test("find in set") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index c07da122cd7b8..5a646d9a850ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -24,25 +24,30 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String -class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { +class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase { private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size) - private def testWithFactory( - name: String)( - f: UnsafeProjectionCreator => Unit): Unit = { - test(name) { - f(UnsafeProjection) - f(InterpretedUnsafeProjection) + private def testBothCodegenAndInterpreted(name: String)(f: => Unit): Unit = { + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + test(s"$name with $fallbackMode") { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + f + } + } } } - testWithFactory("basic conversion with only primitive types") { factory => + testBothCodegenAndInterpreted("basic conversion with only primitive types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) @@ -79,7 +84,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow2.getInt(2) === 2) } - testWithFactory("basic conversion with primitive, string and binary types") { factory => + testBothCodegenAndInterpreted("basic conversion with primitive, string and binary types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) val converter = factory.create(fieldTypes) @@ -98,7 +104,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8)) } - testWithFactory("basic conversion with primitive, string, date and timestamp types") { factory => + testBothCodegenAndInterpreted( + "basic conversion with primitive, string, date and timestamp types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) val converter = factory.create(fieldTypes) @@ -127,7 +135,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { (Timestamp.valueOf("2015-06-22 08:10:25")) } - testWithFactory("null handling") { factory => + testBothCodegenAndInterpreted("null handling") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( NullType, BooleanType, @@ -248,7 +257,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } - testWithFactory("NaN canonicalization") { factory => + testBothCodegenAndInterpreted("NaN canonicalization") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) val row1 = new SpecificInternalRow(fieldTypes) @@ -263,7 +273,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) } - testWithFactory("basic conversion with struct type") { factory => + testBothCodegenAndInterpreted("basic conversion with struct type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("i", IntegerType), new StructType().add("nest", new StructType().add("l", LongType)) @@ -325,7 +336,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) } - testWithFactory("basic conversion with array type") { factory => + testBothCodegenAndInterpreted("basic conversion with array type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( ArrayType(IntegerType), ArrayType(ArrayType(IntegerType)) @@ -355,7 +367,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size) } - testWithFactory("basic conversion with map type") { factory => + testBothCodegenAndInterpreted("basic conversion with map type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( MapType(IntegerType, IntegerType), MapType(IntegerType, MapType(IntegerType, IntegerType)) @@ -401,7 +414,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size) } - testWithFactory("basic conversion with struct and array") { factory => + testBothCodegenAndInterpreted("basic conversion with struct and array") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("arr", ArrayType(IntegerType)), ArrayType(new StructType().add("l", LongType)) @@ -440,7 +454,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - testWithFactory("basic conversion with struct and map") { factory => + testBothCodegenAndInterpreted("basic conversion with struct and map") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("map", MapType(IntegerType, IntegerType)), MapType(IntegerType, new StructType().add("l", LongType)) @@ -486,7 +501,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - testWithFactory("basic conversion with array and map") { factory => + testBothCodegenAndInterpreted("basic conversion with array and map") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( ArrayType(MapType(IntegerType, IntegerType)), MapType(IntegerType, ArrayType(IntegerType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala new file mode 100644 index 0000000000000..55569b6f2933e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class CodeBlockSuite extends SparkFunSuite { + + test("Block interpolates string and ExprValue inputs") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val stringLiteral = "false" + val code = code"boolean $isNull = $stringLiteral;" + assert(code.toString == "boolean expr1_isNull = false;") + } + + test("Literals are folded into string code parts instead of block inputs") { + val value = JavaCode.variable("expr1", IntegerType) + val intLiteral = 1 + val code = code"int $value = $intLiteral;" + assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value)) + } + + test("Block.stripMargin") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val value = JavaCode.variable("expr1", IntegerType) + val code1 = + code""" + |boolean $isNull = false; + |int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin + val expected = + s""" + |boolean expr1_isNull = false; + |int expr1 = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin.trim + assert(code1.toString == expected) + + val code2 = + code""" + >boolean $isNull = false; + >int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin('>') + assert(code2.toString == expected) + } + + test("Block can capture input expr values") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val value = JavaCode.variable("expr1", IntegerType) + val code = + code""" + |boolean $isNull = false; + |int $value = -1; + """.stripMargin + val exprValues = code.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + }.toSet + assert(exprValues.size == 2) + assert(exprValues === Set(value, isNull)) + } + + test("concatenate blocks") { + val isNull1 = JavaCode.isNullVariable("expr1_isNull") + val value1 = JavaCode.variable("expr1", IntegerType) + val isNull2 = JavaCode.isNullVariable("expr2_isNull") + val value2 = JavaCode.variable("expr2", IntegerType) + val literal = JavaCode.literal("100", IntegerType) + + val code = + code""" + |boolean $isNull1 = false; + |int $value1 = -1;""".stripMargin + + code""" + |boolean $isNull2 = true; + |int $value2 = $literal;""".stripMargin + + val expected = + """ + |boolean expr1_isNull = false; + |int expr1 = -1; + |boolean expr2_isNull = true; + |int expr2 = 100;""".stripMargin.trim + + assert(code.toString == expected) + + val exprValues = code.children.flatMap(_.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + }).toSet + assert(exprValues.size == 5) + assert(exprValues === Set(isNull1, value1, isNull2, value2, literal)) + } + + test("Throws exception when interpolating unexcepted object in code block") { + val obj = Tuple2(1, 1) + val e = intercept[IllegalArgumentException] { + code"$obj" + } + assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}")) + } + + test("transform expr in code block") { + val expr = JavaCode.expression("1 + 1", IntegerType) + val isNull = JavaCode.isNullVariable("expr1_isNull") + val exprInFunc = JavaCode.variable("expr1", IntegerType) + + val code = + code""" + |callFunc(int $expr) { + | boolean $isNull = false; + | int $exprInFunc = $expr + 1; + |}""".stripMargin + + val aliasedParam = JavaCode.variable("aliased", expr.javaType) + + // We want to replace all occurrences of `expr` with the variable `aliasedParam`. + val aliasedCode = code.transformExprValues { + case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam + } + val expected = + code""" + |callFunc(int $aliasedParam) { + | boolean $isNull = false; + | int $exprInFunc = $aliasedParam + 1; + |}""".stripMargin + assert(aliasedCode.toString == expected.toString) + } + + test ("transform expr in nested blocks") { + val expr = JavaCode.expression("1 + 1", IntegerType) + val isNull = JavaCode.isNullVariable("expr1_isNull") + val exprInFunc = JavaCode.variable("expr1", IntegerType) + + val funcs = Seq("callFunc1", "callFunc2", "callFunc3") + val subBlocks = funcs.map { funcName => + code""" + |$funcName(int $expr) { + | boolean $isNull = false; + | int $exprInFunc = $expr + 1; + |}""".stripMargin + } + + val aliasedParam = JavaCode.variable("aliased", expr.javaType) + + val block = code"${subBlocks(0)}\n${subBlocks(1)}\n${subBlocks(2)}" + val transformedBlock = block.transform { + case b: Block => b.transformExprValues { + case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam + } + }.asInstanceOf[CodeBlock] + + val expected1 = + code""" + |callFunc1(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val expected2 = + code""" + |callFunc2(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val expected3 = + code""" + |callFunc3(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val exprValues = transformedBlock.children.flatMap { block => + block.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + } + }.toSet + + assert(transformedBlock.children(0).toString == expected1.toString) + assert(transformedBlock.children(1).toString == expected2.toString) + assert(transformedBlock.children(2).toString == expected3.toString) + assert(transformedBlock.toString == (expected1 + expected2 + expected3).toString) + assert(exprValues === Set(isNull, exprInFunc, aliasedParam)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala index c4cde7091154b..0fec15bc42c17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala @@ -77,6 +77,27 @@ class UDFXPathUtilSuite extends SparkFunSuite { assert(ret == "foo") } + test("embedFailure") { + import org.apache.commons.io.FileUtils + import java.io.File + val secretValue = String.valueOf(Math.random) + val tempFile = File.createTempFile("verifyembed", ".tmp") + tempFile.deleteOnExit() + val fname = tempFile.getAbsolutePath + + FileUtils.writeStringToFile(tempFile, secretValue) + + val xml = + s""" + | + |]> + |&embed; + """.stripMargin + val evaled = new UDFXPathUtil().evalString(xml, "/foo") + assert(evaled.isEmpty) + } + test("number eval") { var ret = util.evalNumber("truefalseb3c1-77", "a/c[2]") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala index bfa18a0919e45..c6f6d3abb860c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala @@ -40,8 +40,9 @@ class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { // Test error message for invalid XML document val e1 = intercept[RuntimeException] { testExpr("/a>", "a", null.asInstanceOf[T]) } - assert(e1.getCause.getMessage.contains("Invalid XML document") && - e1.getCause.getMessage.contains("/a>")) + assert(e1.getCause.getCause.getMessage.contains( + "XML document structures must start and end within the same entity.")) + assert(e1.getMessage.contains("/a>")) // Test error message for invalid xpath val e2 = intercept[RuntimeException] { testExpr("", "!#$", null.asInstanceOf[T]) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala index 21220b38968e8..788fedb3c8e8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala @@ -56,7 +56,7 @@ class CheckCartesianProductsSuite extends PlanTest { val thrownException = the [AnalysisException] thrownBy { performCartesianProductCheck(joinType) } - assert(thrownException.message.contains("Detected cartesian product")) + assert(thrownException.message.contains("Detected implicit cartesian product")) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 3f41f4b144096..8b05ba32e6eef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.optimizer import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -370,5 +369,13 @@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized2, expected2.analyze) } + test("SPARK-24696 ColumnPruning rule fails to remove extra Project") { + val input = LocalRelation('key.int, 'value.string) + val query = input.select('key).where(rand(0L) > 0.5).where('key < 10).analyze + val optimized = Optimize.execute(query) + val expected = input.where(rand(0L) > 0.5).where('key < 10).select('key).analyze + comparePlans(optimized, expected) + } + // todo: add more tests for column pruning } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index e068f51044589..e4671f0d1cce6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -35,11 +35,25 @@ class InferFiltersFromConstraintsSuite extends PlanTest { InferFiltersFromConstraints, CombineFilters, SimplifyBinaryComparison, - BooleanSimplification) :: Nil + BooleanSimplification, + PruneFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + private def testConstraintsAfterJoin( + x: LogicalPlan, + y: LogicalPlan, + expectedLeft: LogicalPlan, + expectedRight: LogicalPlan, + joinType: JoinType) = { + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, joinType, condition).analyze + val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + test("filter: filter out constraints in condition") { val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze val correctAnswer = testRelation @@ -196,13 +210,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y, LeftSemi, condition).analyze - val left = x.where(IsNotNull('a)) - val right = y.where(IsNotNull('a)) - val correctAnswer = left.join(right, LeftSemi, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi) } test("SPARK-21479: Outer join after-join filters push down to null-supplying side") { @@ -232,12 +240,27 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("SPARK-21479: Outer join no filter push down to preserved side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y.where("y.a".attr === 1), LeftOuter, condition).analyze - val left = x - val right = y.where(IsNotNull('a) && 'a === 1) - val correctAnswer = left.join(right, LeftOuter, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin( + x, y.where("a".attr === 1), + x, y.where(IsNotNull('a) && 'a === 1), + LeftOuter) + } + + test("SPARK-23564: left anti join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti) + } + + test("SPARK-23564: left outer join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftOuter) + } + + test("SPARK-23564: right outer join should filter out null join keys on left side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala index 2319ab8046e56..dae5e6f3ee3dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala @@ -17,16 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} -import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} class RemoveRedundantSortsSuite extends PlanTest { @@ -42,15 +38,15 @@ class RemoveRedundantSortsSuite extends PlanTest { test("remove redundant order by") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst) val optimized = Optimize.execute(unnecessaryReordered.analyze) - val correctAnswer = orderedPlan.select('a).analyze + val correctAnswer = orderedPlan.limit(2).select('a).analyze comparePlans(Optimize.execute(optimized), correctAnswer) } test("do not remove sort if the order is different") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc) + val reorderedDifferently = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc) val optimized = Optimize.execute(reorderedDifferently.analyze) val correctAnswer = reorderedDifferently.analyze comparePlans(optimized, correctAnswer) @@ -72,6 +68,14 @@ class RemoveRedundantSortsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("different sorts are not simplified if limit is in between") { + val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10)) + .orderBy('a.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = orderedPlan.analyze + comparePlans(optimized, correctAnswer) + } + test("range is already sorted") { val inputPlan = Range(1L, 1000L, 1, 10) val orderedPlan = inputPlan.orderBy('id.asc) @@ -98,4 +102,37 @@ class RemoveRedundantSortsSuite extends PlanTest { val correctAnswer = groupedAndResorted.analyze comparePlans(optimized, correctAnswer) } + + test("remove two consecutive sorts") { + val orderedTwice = testRelation.orderBy('a.asc).orderBy('b.desc) + val optimized = Optimize.execute(orderedTwice.analyze) + val correctAnswer = testRelation.orderBy('b.desc).analyze + comparePlans(optimized, correctAnswer) + } + + test("remove sorts separated by Filter/Project operators") { + val orderedTwiceWithProject = testRelation.orderBy('a.asc).select('b).orderBy('b.desc) + val optimizedWithProject = Optimize.execute(orderedTwiceWithProject.analyze) + val correctAnswerWithProject = testRelation.select('b).orderBy('b.desc).analyze + comparePlans(optimizedWithProject, correctAnswerWithProject) + + val orderedTwiceWithFilter = + testRelation.orderBy('a.asc).where('b > Literal(0)).orderBy('b.desc) + val optimizedWithFilter = Optimize.execute(orderedTwiceWithFilter.analyze) + val correctAnswerWithFilter = testRelation.where('b > Literal(0)).orderBy('b.desc).analyze + comparePlans(optimizedWithFilter, correctAnswerWithFilter) + + val orderedTwiceWithBoth = + testRelation.orderBy('a.asc).select('b).where('b > Literal(0)).orderBy('b.desc) + val optimizedWithBoth = Optimize.execute(orderedTwiceWithBoth.analyze) + val correctAnswerWithBoth = + testRelation.select('b).where('b > Literal(0)).orderBy('b.desc).analyze + comparePlans(optimizedWithBoth, correctAnswerWithBoth) + + val orderedThrice = orderedTwiceWithBoth.select(('b + 1).as('c)).orderBy('c.asc) + val optimizedThrice = Optimize.execute(orderedThrice.analyze) + val correctAnswerThrice = testRelation.select('b).where('b > Literal(0)) + .select(('b + 1).as('c)).orderBy('c.asc).analyze + comparePlans(optimizedThrice, correctAnswerThrice) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 633d86d495581..5452e72b38647 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -439,4 +439,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select('c as 'sCol2, 'a as 'sCol1) checkRule(originalQuery, correctAnswer) } + + test("SPARK-24313: support binary type as map keys in GetMapValue") { + val mb0 = Literal.create( + Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), + MapType(BinaryType, StringType)) + val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) + + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](1, 2, 3))), null) + + checkEvaluation(GetMapValue(mb1, Literal(Array[Byte](1, 2))), null) + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 812bfdd7bb885..fb51376c6163f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -318,6 +318,16 @@ class PlanParserSuite extends AnalysisTest { assertEqual( "select * from t lateral view posexplode(x) posexpl as x, y", expected) + + intercept( + """select * + |from t + |lateral view explode(x) expl + |pivot ( + | sum(x) + | FOR y IN ('a', 'b') + |)""".stripMargin, + "LATERAL cannot be used together with PIVOT in FROM clause") } test("joins") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index cc80a41df998d..ff0de0fb7c1f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -41,17 +41,17 @@ class TableIdentifierParserSuite extends SparkFunSuite { "sort", "sorted", "ssl", "statistics", "stored", "streamtable", "string", "struct", "tables", "tblproperties", "temporary", "terminated", "tinyint", "touch", "transactions", "unarchive", "undo", "uniontype", "unlock", "unset", "unsigned", "uri", "use", "utc", "utctimestamp", - "view", "while", "year", "work", "transaction", "write", "isolation", "level", - "snapshot", "autocommit", "all", "alter", "array", "as", "authorization", "between", "bigint", + "view", "while", "year", "work", "transaction", "write", "isolation", "level", "snapshot", + "autocommit", "all", "any", "alter", "array", "as", "authorization", "between", "bigint", "binary", "boolean", "both", "by", "create", "cube", "current_date", "current_timestamp", "cursor", "date", "decimal", "delete", "describe", "double", "drop", "exists", "external", "false", "fetch", "float", "for", "grant", "group", "grouping", "import", "in", - "insert", "int", "into", "is", "lateral", "like", "local", "none", "null", + "insert", "int", "into", "is", "pivot", "lateral", "like", "local", "none", "null", "of", "order", "out", "outer", "partition", "percent", "procedure", "range", "reads", "revoke", "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float", - "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing") + "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing", "extract") val hiveStrictNonReservedKeyword = Seq("anti", "full", "inner", "left", "semi", "right", "natural", "union", "intersect", "except", "database", "on", "join", "cross", "select", "from", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 14041747fd20e..bf569cb869428 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Coalesce, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType @@ -101,4 +101,22 @@ class LogicalPlanSuite extends SparkFunSuite { assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true) assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming) } + + test("transformExpressions works with a Stream") { + val id1 = NamedExpression.newExprId + val id2 = NamedExpression.newExprId + val plan = Project(Stream( + Alias(Literal(1), "a")(exprId = id1), + Alias(Literal(2), "b")(exprId = id2)), + OneRowRelation()) + val result = plan.transformExpressions { + case Literal(v: Int, IntegerType) if v != 1 => + Literal(v + 1, IntegerType) + } + val expected = Project(Stream( + Alias(Literal(1), "a")(exprId = id1), + Alias(Literal(3), "b")(exprId = id2)), + OneRowRelation()) + assert(result.sameResult(expected)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 43440d51dede6..47bfa62569583 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -357,6 +357,29 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 3) } + test("evaluateInSet with all zeros") { + validateEstimatedStats( + Filter(InSet(attrString, Set(3, 4, 5)), + StatsTestPlan(Seq(attrString), 0, + AttributeMap(Seq(attrString -> + ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(0), maxLen = Some(0)))))), + Seq(attrString -> ColumnStat(distinctCount = Some(0))), + expectedRowCount = 0) + } + + test("evaluateInSet with string") { + validateEstimatedStats( + Filter(InSet(attrString, Set("A0")), + StatsTestPlan(Seq(attrString), 10, + AttributeMap(Seq(attrString -> + ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)))))), + Seq(attrString -> ColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), + expectedRowCount = 1) + } + test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 84d0ba7bef642..b7092f4c42d4c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -29,14 +29,14 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.dsl.expressions.DslString import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union} import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition} -import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback { @@ -574,4 +574,25 @@ class TreeNodeSuite extends SparkFunSuite { val right = JsonMethods.parse(rightJson) assert(left == right) } + + test("transform works on stream of children") { + val before = Coalesce(Stream(Literal(1), Literal(2))) + // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the + // situation in which the TreeNode.mapChildren function's change detection is not triggered. A + // stream's first element is typically materialized, so in order to not trip the TreeNode change + // detection logic, we should not change the first element in the sequence. + val result = before.transform { + case Literal(v: Int, IntegerType) if v != 1 => + Literal(v + 1, IntegerType) + } + val expected = Coalesce(Stream(Literal(1), Literal(3))) + assert(result === expected) + } + + test("withNewChildren on stream of children") { + val before = Coalesce(Stream(Literal(1), Literal(2))) + val result = before.withNewChildren(Stream(Literal(1), Literal(3))) + val expected = Coalesce(Stream(Literal(1), Literal(3))) + assert(result === expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala index 9d285916bcf42..229e32479082c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala @@ -104,4 +104,40 @@ class ComplexDataSuite extends SparkFunSuite { // The copied data should not be changed externally. assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a") } + + test("SPARK-24659: GenericArrayData.equals should respect element type differences") { + import scala.reflect.ClassTag + + // Expected positive cases + def arraysShouldEqual[T: ClassTag](element: T*): Unit = { + val array1 = new GenericArrayData(Array[T](element: _*)) + val array2 = new GenericArrayData(Array[T](element: _*)) + assert(array1.equals(array2)) + } + arraysShouldEqual(true, false) // Boolean + arraysShouldEqual(0.toByte, 123.toByte, Byte.MinValue, Byte.MaxValue) // Byte + arraysShouldEqual(0.toShort, 123.toShort, Short.MinValue, Short.MaxValue) // Short + arraysShouldEqual(0, 123, -65536, Int.MinValue, Int.MaxValue) // Int + arraysShouldEqual(0L, 123L, -65536L, Long.MinValue, Long.MaxValue) // Long + arraysShouldEqual(0.0F, 123.0F, Float.MinValue, Float.MaxValue, Float.MinPositiveValue, + Float.PositiveInfinity, Float.NegativeInfinity, Float.NaN) // Float + arraysShouldEqual(0.0, 123.0, Double.MinValue, Double.MaxValue, Double.MinPositiveValue, + Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN) // Double + arraysShouldEqual(Array[Byte](123.toByte), Array[Byte](), null) // SQL Binary + arraysShouldEqual(UTF8String.fromString("foo"), null) // SQL String + + // Expected negative cases + // Spark SQL considers cases like array vs array to be incompatible, + // so an underlying implementation of array type should return false in such cases. + def arraysShouldNotEqual[T: ClassTag, U: ClassTag](element1: T, element2: U): Unit = { + val array1 = new GenericArrayData(Array[T](element1)) + val array2 = new GenericArrayData(Array[U](element2)) + assert(!array1.equals(array2)) + } + arraysShouldNotEqual(true, 1) // Boolean <-> Int + arraysShouldNotEqual(123.toByte, 123) // Byte <-> Int + arraysShouldNotEqual(123.toByte, 123L) // Byte <-> Long + arraysShouldNotEqual(123.toShort, 123) // Short <-> Int + arraysShouldNotEqual(123, 123L) // Int <-> Long + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 625ff38943fa3..cbf6106697f30 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -490,24 +490,36 @@ class DateTimeUtilsSuite extends SparkFunSuite { c1.set(1997, 1, 28, 10, 30, 0) val c2 = Calendar.getInstance() c2.set(1996, 9, 30, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 3.94959677) - c2.set(2000, 1, 28, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) - c2.set(2000, 1, 29, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) - c2.set(1996, 2, 31, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 11) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, true, c1.getTimeZone) === 3.94959677) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, false, c1.getTimeZone) + === 3.9495967741935485) + Seq(true, false).foreach { roundOff => + c2.set(2000, 1, 28, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === -36) + c2.set(2000, 1, 29, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === -36) + c2.set(1996, 2, 31, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === 11) + } val c3 = Calendar.getInstance(TimeZonePST) c3.set(2000, 1, 28, 16, 0, 0) val c4 = Calendar.getInstance(TimeZonePST) c4.set(1997, 1, 28, 16, 0, 0) assert( - monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZonePST) + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZonePST) === 36.0) assert( - monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZoneGMT) + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZoneGMT) === 35.90322581) + assert( + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, false, TimeZoneGMT) + === 35.903225806451616) } test("from UTC timestamp") { diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt new file mode 100644 index 0000000000000..4f38cc4cee96d --- /dev/null +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -0,0 +1,704 @@ +================================================================================================ +Pushdown for many distinct value case +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8970 / 9122 1.8 570.3 1.0X +Parquet Vectorized (Pushdown) 471 / 491 33.4 30.0 19.0X +Native ORC Vectorized 7661 / 7853 2.1 487.0 1.2X +Native ORC Vectorized (Pushdown) 1134 / 1161 13.9 72.1 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 string row ('7864320' < value < '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9246 / 9297 1.7 587.8 1.0X +Parquet Vectorized (Pushdown) 480 / 488 32.8 30.5 19.3X +Native ORC Vectorized 7838 / 7850 2.0 498.3 1.2X +Native ORC Vectorized (Pushdown) 1054 / 1118 14.9 67.0 8.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row (value = '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8989 / 9100 1.7 571.5 1.0X +Parquet Vectorized (Pushdown) 448 / 467 35.1 28.5 20.1X +Native ORC Vectorized 7680 / 7768 2.0 488.3 1.2X +Native ORC Vectorized (Pushdown) 1067 / 1118 14.7 67.8 8.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row (value <=> '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9115 / 9266 1.7 579.5 1.0X +Parquet Vectorized (Pushdown) 466 / 492 33.7 29.7 19.5X +Native ORC Vectorized 7800 / 7914 2.0 495.9 1.2X +Native ORC Vectorized (Pushdown) 1075 / 1102 14.6 68.4 8.5X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row ('7864320' <= value <= '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9099 / 9237 1.7 578.5 1.0X +Parquet Vectorized (Pushdown) 462 / 475 34.1 29.3 19.7X +Native ORC Vectorized 7847 / 7925 2.0 498.9 1.2X +Native ORC Vectorized (Pushdown) 1078 / 1114 14.6 68.5 8.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 19303 / 19547 0.8 1227.3 1.0X +Parquet Vectorized (Pushdown) 19924 / 20089 0.8 1266.7 1.0X +Native ORC Vectorized 18725 / 19079 0.8 1190.5 1.0X +Native ORC Vectorized (Pushdown) 19310 / 19492 0.8 1227.7 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 int row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8117 / 8323 1.9 516.1 1.0X +Parquet Vectorized (Pushdown) 484 / 494 32.5 30.8 16.8X +Native ORC Vectorized 6811 / 7036 2.3 433.0 1.2X +Native ORC Vectorized (Pushdown) 1061 / 1082 14.8 67.5 7.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 int row (7864320 < value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8105 / 8140 1.9 515.3 1.0X +Parquet Vectorized (Pushdown) 478 / 505 32.9 30.4 17.0X +Native ORC Vectorized 6914 / 7211 2.3 439.6 1.2X +Native ORC Vectorized (Pushdown) 1044 / 1064 15.1 66.4 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7983 / 8116 2.0 507.6 1.0X +Parquet Vectorized (Pushdown) 464 / 487 33.9 29.5 17.2X +Native ORC Vectorized 6703 / 6774 2.3 426.1 1.2X +Native ORC Vectorized (Pushdown) 1017 / 1058 15.5 64.6 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (value <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7942 / 7983 2.0 504.9 1.0X +Parquet Vectorized (Pushdown) 468 / 479 33.6 29.7 17.0X +Native ORC Vectorized 6677 / 6779 2.4 424.5 1.2X +Native ORC Vectorized (Pushdown) 1021 / 1068 15.4 64.9 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (7864320 <= value <= 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7909 / 7958 2.0 502.8 1.0X +Parquet Vectorized (Pushdown) 485 / 494 32.4 30.8 16.3X +Native ORC Vectorized 6751 / 6846 2.3 429.2 1.2X +Native ORC Vectorized (Pushdown) 1043 / 1077 15.1 66.3 7.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (7864319 < value < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8010 / 8033 2.0 509.2 1.0X +Parquet Vectorized (Pushdown) 472 / 489 33.3 30.0 17.0X +Native ORC Vectorized 6655 / 6808 2.4 423.1 1.2X +Native ORC Vectorized (Pushdown) 1015 / 1067 15.5 64.5 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% int rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8983 / 9035 1.8 571.1 1.0X +Parquet Vectorized (Pushdown) 2204 / 2231 7.1 140.1 4.1X +Native ORC Vectorized 7864 / 8011 2.0 500.0 1.1X +Native ORC Vectorized (Pushdown) 2674 / 2789 5.9 170.0 3.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% int rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12723 / 12903 1.2 808.9 1.0X +Parquet Vectorized (Pushdown) 9112 / 9282 1.7 579.3 1.4X +Native ORC Vectorized 12090 / 12230 1.3 768.7 1.1X +Native ORC Vectorized (Pushdown) 9242 / 9372 1.7 587.6 1.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% int rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16453 / 16678 1.0 1046.1 1.0X +Parquet Vectorized (Pushdown) 15997 / 16262 1.0 1017.0 1.0X +Native ORC Vectorized 16652 / 17070 0.9 1058.7 1.0X +Native ORC Vectorized (Pushdown) 15843 / 16112 1.0 1007.2 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 17098 / 17254 0.9 1087.1 1.0X +Parquet Vectorized (Pushdown) 17302 / 17529 0.9 1100.1 1.0X +Native ORC Vectorized 16790 / 17098 0.9 1067.5 1.0X +Native ORC Vectorized (Pushdown) 17329 / 17914 0.9 1101.7 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 17088 / 17392 0.9 1086.4 1.0X +Parquet Vectorized (Pushdown) 17609 / 17863 0.9 1119.5 1.0X +Native ORC Vectorized 18334 / 69831 0.9 1165.7 0.9X +Native ORC Vectorized (Pushdown) 17465 / 17629 0.9 1110.4 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16903 / 17233 0.9 1074.6 1.0X +Parquet Vectorized (Pushdown) 16945 / 17032 0.9 1077.3 1.0X +Native ORC Vectorized 16377 / 16762 1.0 1041.2 1.0X +Native ORC Vectorized (Pushdown) 16950 / 17212 0.9 1077.7 1.0X + + +================================================================================================ +Pushdown for few distinct value case (use dictionary encoding) +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 distinct string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7245 / 7322 2.2 460.7 1.0X +Parquet Vectorized (Pushdown) 378 / 389 41.6 24.0 19.2X +Native ORC Vectorized 6720 / 6778 2.3 427.2 1.1X +Native ORC Vectorized (Pushdown) 1009 / 1032 15.6 64.2 7.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 distinct string row ('100' < value < '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7627 / 7795 2.1 484.9 1.0X +Parquet Vectorized (Pushdown) 384 / 406 41.0 24.4 19.9X +Native ORC Vectorized 6724 / 7824 2.3 427.5 1.1X +Native ORC Vectorized (Pushdown) 968 / 986 16.3 61.5 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row (value = '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7157 / 7534 2.2 455.0 1.0X +Parquet Vectorized (Pushdown) 542 / 565 29.0 34.5 13.2X +Native ORC Vectorized 6716 / 7214 2.3 427.0 1.1X +Native ORC Vectorized (Pushdown) 1212 / 1288 13.0 77.0 5.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row (value <=> '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7368 / 7552 2.1 468.4 1.0X +Parquet Vectorized (Pushdown) 544 / 556 28.9 34.6 13.5X +Native ORC Vectorized 6740 / 6867 2.3 428.5 1.1X +Native ORC Vectorized (Pushdown) 1230 / 1426 12.8 78.2 6.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row ('100' <= value <= '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7427 / 7734 2.1 472.2 1.0X +Parquet Vectorized (Pushdown) 556 / 568 28.3 35.4 13.3X +Native ORC Vectorized 6847 / 7059 2.3 435.3 1.1X +Native ORC Vectorized (Pushdown) 1226 / 1230 12.8 77.9 6.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all distinct string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16998 / 17311 0.9 1080.7 1.0X +Parquet Vectorized (Pushdown) 16977 / 17250 0.9 1079.4 1.0X +Native ORC Vectorized 18447 / 19852 0.9 1172.8 0.9X +Native ORC Vectorized (Pushdown) 16614 / 17102 0.9 1056.3 1.0X + + +================================================================================================ +Pushdown benchmark for StringStartsWith +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '10%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9705 / 10814 1.6 617.0 1.0X +Parquet Vectorized (Pushdown) 3086 / 3574 5.1 196.2 3.1X +Native ORC Vectorized 10094 / 10695 1.6 641.8 1.0X +Native ORC Vectorized (Pushdown) 9611 / 9999 1.6 611.0 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '1000%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8016 / 8183 2.0 509.7 1.0X +Parquet Vectorized (Pushdown) 444 / 457 35.4 28.2 18.0X +Native ORC Vectorized 6970 / 7169 2.3 443.2 1.2X +Native ORC Vectorized (Pushdown) 7447 / 7503 2.1 473.5 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '786432%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7908 / 8046 2.0 502.8 1.0X +Parquet Vectorized (Pushdown) 408 / 429 38.6 25.9 19.4X +Native ORC Vectorized 7021 / 7100 2.2 446.4 1.1X +Native ORC Vectorized (Pushdown) 7310 / 7490 2.2 464.8 1.1X + + +================================================================================================ +Pushdown benchmark for decimal +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(9, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 3785 / 3867 4.2 240.6 1.0X +Parquet Vectorized (Pushdown) 3820 / 3928 4.1 242.9 1.0X +Native ORC Vectorized 3981 / 4049 4.0 253.1 1.0X +Native ORC Vectorized (Pushdown) 702 / 735 22.4 44.6 5.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(9, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4694 / 4813 3.4 298.4 1.0X +Parquet Vectorized (Pushdown) 4839 / 4907 3.3 307.6 1.0X +Native ORC Vectorized 4943 / 5032 3.2 314.2 0.9X +Native ORC Vectorized (Pushdown) 2043 / 2085 7.7 129.9 2.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(9, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8321 / 8472 1.9 529.0 1.0X +Parquet Vectorized (Pushdown) 8125 / 8471 1.9 516.6 1.0X +Native ORC Vectorized 8524 / 8616 1.8 541.9 1.0X +Native ORC Vectorized (Pushdown) 7961 / 8383 2.0 506.1 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(9, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9587 / 10112 1.6 609.5 1.0X +Parquet Vectorized (Pushdown) 9726 / 10370 1.6 618.3 1.0X +Native ORC Vectorized 10119 / 11147 1.6 643.4 0.9X +Native ORC Vectorized (Pushdown) 9366 / 9497 1.7 595.5 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(18, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4060 / 4093 3.9 258.1 1.0X +Parquet Vectorized (Pushdown) 4037 / 4125 3.9 256.6 1.0X +Native ORC Vectorized 4756 / 4811 3.3 302.4 0.9X +Native ORC Vectorized (Pushdown) 824 / 889 19.1 52.4 4.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(18, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5157 / 5271 3.0 327.9 1.0X +Parquet Vectorized (Pushdown) 5051 / 5141 3.1 321.1 1.0X +Native ORC Vectorized 5723 / 6146 2.7 363.9 0.9X +Native ORC Vectorized (Pushdown) 2198 / 2317 7.2 139.8 2.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(18, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8608 / 8647 1.8 547.3 1.0X +Parquet Vectorized (Pushdown) 8471 / 8584 1.9 538.6 1.0X +Native ORC Vectorized 9249 / 10048 1.7 588.0 0.9X +Native ORC Vectorized (Pushdown) 7645 / 8091 2.1 486.1 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(18, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 11658 / 11888 1.3 741.2 1.0X +Parquet Vectorized (Pushdown) 11812 / 12098 1.3 751.0 1.0X +Native ORC Vectorized 12943 / 13312 1.2 822.9 0.9X +Native ORC Vectorized (Pushdown) 13139 / 13465 1.2 835.4 0.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(38, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5491 / 5716 2.9 349.1 1.0X +Parquet Vectorized (Pushdown) 5515 / 5615 2.9 350.6 1.0X +Native ORC Vectorized 4582 / 4654 3.4 291.3 1.2X +Native ORC Vectorized (Pushdown) 815 / 861 19.3 51.8 6.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(38, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 6432 / 6527 2.4 409.0 1.0X +Parquet Vectorized (Pushdown) 6513 / 6607 2.4 414.1 1.0X +Native ORC Vectorized 5618 / 6085 2.8 357.2 1.1X +Native ORC Vectorized (Pushdown) 2403 / 2443 6.5 152.8 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(38, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 11041 / 11467 1.4 701.9 1.0X +Parquet Vectorized (Pushdown) 10909 / 11484 1.4 693.5 1.0X +Native ORC Vectorized 9860 / 10436 1.6 626.9 1.1X +Native ORC Vectorized (Pushdown) 7908 / 8069 2.0 502.8 1.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(38, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 14816 / 16877 1.1 942.0 1.0X +Parquet Vectorized (Pushdown) 15383 / 15740 1.0 978.0 1.0X +Native ORC Vectorized 14408 / 14771 1.1 916.0 1.0X +Native ORC Vectorized (Pushdown) 13968 / 14805 1.1 888.1 1.1X + + +================================================================================================ +Pushdown benchmark for InSet -> InFilters +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7993 / 8104 2.0 508.2 1.0X +Parquet Vectorized (Pushdown) 507 / 532 31.0 32.2 15.8X +Native ORC Vectorized 6922 / 7163 2.3 440.1 1.2X +Native ORC Vectorized (Pushdown) 1017 / 1058 15.5 64.6 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7855 / 7963 2.0 499.4 1.0X +Parquet Vectorized (Pushdown) 503 / 516 31.3 32.0 15.6X +Native ORC Vectorized 6825 / 6954 2.3 433.9 1.2X +Native ORC Vectorized (Pushdown) 1019 / 1044 15.4 64.8 7.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7858 / 7928 2.0 499.6 1.0X +Parquet Vectorized (Pushdown) 490 / 519 32.1 31.1 16.0X +Native ORC Vectorized 7079 / 7966 2.2 450.1 1.1X +Native ORC Vectorized (Pushdown) 1276 / 1673 12.3 81.1 6.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8007 / 11155 2.0 509.0 1.0X +Parquet Vectorized (Pushdown) 519 / 540 30.3 33.0 15.4X +Native ORC Vectorized 6848 / 7072 2.3 435.4 1.2X +Native ORC Vectorized (Pushdown) 1026 / 1050 15.3 65.2 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7876 / 7956 2.0 500.7 1.0X +Parquet Vectorized (Pushdown) 521 / 535 30.2 33.1 15.1X +Native ORC Vectorized 7051 / 7368 2.2 448.3 1.1X +Native ORC Vectorized (Pushdown) 1014 / 1035 15.5 64.5 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7897 / 8229 2.0 502.1 1.0X +Parquet Vectorized (Pushdown) 513 / 530 30.7 32.6 15.4X +Native ORC Vectorized 6730 / 6990 2.3 427.9 1.2X +Native ORC Vectorized (Pushdown) 1003 / 1036 15.7 63.8 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7967 / 8175 2.0 506.5 1.0X +Parquet Vectorized (Pushdown) 8155 / 8434 1.9 518.5 1.0X +Native ORC Vectorized 7002 / 7107 2.2 445.2 1.1X +Native ORC Vectorized (Pushdown) 1092 / 1139 14.4 69.4 7.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8032 / 8122 2.0 510.7 1.0X +Parquet Vectorized (Pushdown) 8141 / 8908 1.9 517.6 1.0X +Native ORC Vectorized 7140 / 7387 2.2 454.0 1.1X +Native ORC Vectorized (Pushdown) 1156 / 1220 13.6 73.5 6.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8088 / 8350 1.9 514.2 1.0X +Parquet Vectorized (Pushdown) 8629 / 8702 1.8 548.6 0.9X +Native ORC Vectorized 7480 / 7886 2.1 475.6 1.1X +Native ORC Vectorized (Pushdown) 1106 / 1145 14.2 70.3 7.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8028 / 8165 2.0 510.4 1.0X +Parquet Vectorized (Pushdown) 8349 / 8674 1.9 530.8 1.0X +Native ORC Vectorized 7107 / 7354 2.2 451.8 1.1X +Native ORC Vectorized (Pushdown) 1175 / 1207 13.4 74.7 6.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8041 / 8195 2.0 511.2 1.0X +Parquet Vectorized (Pushdown) 8466 / 8604 1.9 538.2 0.9X +Native ORC Vectorized 7116 / 7286 2.2 452.4 1.1X +Native ORC Vectorized (Pushdown) 1197 / 1214 13.1 76.1 6.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7998 / 8311 2.0 508.5 1.0X +Parquet Vectorized (Pushdown) 9366 / 11257 1.7 595.5 0.9X +Native ORC Vectorized 7856 / 9273 2.0 499.5 1.0X +Native ORC Vectorized (Pushdown) 1350 / 1747 11.7 85.8 5.9X + + +================================================================================================ +Pushdown benchmark for tinyint +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 tinyint row (value = CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 3461 / 3997 4.5 220.1 1.0X +Parquet Vectorized (Pushdown) 270 / 315 58.4 17.1 12.8X +Native ORC Vectorized 4107 / 5372 3.8 261.1 0.8X +Native ORC Vectorized (Pushdown) 778 / 1553 20.2 49.5 4.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% tinyint rows (value < CAST(12 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4771 / 6655 3.3 303.3 1.0X +Parquet Vectorized (Pushdown) 1322 / 1606 11.9 84.0 3.6X +Native ORC Vectorized 4437 / 4572 3.5 282.1 1.1X +Native ORC Vectorized (Pushdown) 1781 / 1976 8.8 113.2 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% tinyint rows (value < CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7433 / 7752 2.1 472.6 1.0X +Parquet Vectorized (Pushdown) 5863 / 5913 2.7 372.8 1.3X +Native ORC Vectorized 7986 / 8084 2.0 507.7 0.9X +Native ORC Vectorized (Pushdown) 6522 / 6608 2.4 414.6 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% tinyint rows (value < CAST(114 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 11190 / 11519 1.4 711.4 1.0X +Parquet Vectorized (Pushdown) 10861 / 11206 1.4 690.5 1.0X +Native ORC Vectorized 11622 / 12196 1.4 738.9 1.0X +Native ORC Vectorized (Pushdown) 11377 / 11654 1.4 723.3 1.0X + + +================================================================================================ +Pushdown benchmark for Timestamp +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as INT96 row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4784 / 4956 3.3 304.2 1.0X +Parquet Vectorized (Pushdown) 4838 / 4917 3.3 307.6 1.0X +Native ORC Vectorized 3923 / 4173 4.0 249.4 1.2X +Native ORC Vectorized (Pushdown) 894 / 943 17.6 56.8 5.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as INT96 rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5686 / 5901 2.8 361.5 1.0X +Parquet Vectorized (Pushdown) 5555 / 5895 2.8 353.2 1.0X +Native ORC Vectorized 4844 / 4957 3.2 308.0 1.2X +Native ORC Vectorized (Pushdown) 2141 / 2230 7.3 136.1 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as INT96 rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9100 / 9421 1.7 578.6 1.0X +Parquet Vectorized (Pushdown) 9122 / 9496 1.7 580.0 1.0X +Native ORC Vectorized 8365 / 8874 1.9 531.9 1.1X +Native ORC Vectorized (Pushdown) 7128 / 7376 2.2 453.2 1.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as INT96 rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12764 / 13120 1.2 811.5 1.0X +Parquet Vectorized (Pushdown) 12656 / 13003 1.2 804.7 1.0X +Native ORC Vectorized 13096 / 13233 1.2 832.6 1.0X +Native ORC Vectorized (Pushdown) 12710 / 15611 1.2 808.1 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as TIMESTAMP_MICROS row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4381 / 4796 3.6 278.5 1.0X +Parquet Vectorized (Pushdown) 122 / 137 129.3 7.7 36.0X +Native ORC Vectorized 3913 / 3988 4.0 248.8 1.1X +Native ORC Vectorized (Pushdown) 905 / 945 17.4 57.6 4.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5145 / 5184 3.1 327.1 1.0X +Parquet Vectorized (Pushdown) 1426 / 1519 11.0 90.7 3.6X +Native ORC Vectorized 4827 / 4901 3.3 306.9 1.1X +Native ORC Vectorized (Pushdown) 2133 / 2210 7.4 135.6 2.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9234 / 9516 1.7 587.1 1.0X +Parquet Vectorized (Pushdown) 6752 / 7046 2.3 429.3 1.4X +Native ORC Vectorized 8418 / 8998 1.9 535.2 1.1X +Native ORC Vectorized (Pushdown) 7199 / 7314 2.2 457.7 1.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12414 / 12458 1.3 789.2 1.0X +Parquet Vectorized (Pushdown) 12094 / 12249 1.3 768.9 1.0X +Native ORC Vectorized 12198 / 13755 1.3 775.5 1.0X +Native ORC Vectorized (Pushdown) 12205 / 12431 1.3 776.0 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as TIMESTAMP_MILLIS row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4369 / 4515 3.6 277.8 1.0X +Parquet Vectorized (Pushdown) 116 / 125 136.2 7.3 37.8X +Native ORC Vectorized 3965 / 4703 4.0 252.1 1.1X +Native ORC Vectorized (Pushdown) 892 / 1162 17.6 56.7 4.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5211 / 5409 3.0 331.3 1.0X +Parquet Vectorized (Pushdown) 1427 / 1438 11.0 90.7 3.7X +Native ORC Vectorized 4719 / 4883 3.3 300.1 1.1X +Native ORC Vectorized (Pushdown) 2191 / 2228 7.2 139.3 2.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8716 / 8953 1.8 554.2 1.0X +Parquet Vectorized (Pushdown) 6632 / 6968 2.4 421.7 1.3X +Native ORC Vectorized 8376 / 9118 1.9 532.5 1.0X +Native ORC Vectorized (Pushdown) 7218 / 7609 2.2 458.9 1.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12264 / 12452 1.3 779.7 1.0X +Parquet Vectorized (Pushdown) 11766 / 11927 1.3 748.0 1.0X +Native ORC Vectorized 12101 / 12301 1.3 769.3 1.0X +Native ORC Vectorized (Pushdown) 11983 / 12651 1.3 761.9 1.0X + diff --git a/sql/core/pom.xml b/sql/core/pom.xml index ef41837f89d68..18ae314309d7b 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -38,7 +38,7 @@ com.univocity univocity-parsers - 2.5.9 + 2.6.3 jar @@ -118,7 +118,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded org.scalacheck diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index c7c4c7b3e7715..c8cf44b51df77 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -20,8 +20,8 @@ import java.io.IOException; import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; import org.apache.spark.internal.config.package$; -import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -82,7 +82,7 @@ public static boolean supportsAggregationBufferSchema(StructType schema) { * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. * @param groupingKeySchema the schema of the grouping key, used for row conversion. - * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures. + * @param taskContext the current task context. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. */ @@ -90,19 +90,26 @@ public UnsafeFixedWidthAggregationMap( InternalRow emptyAggregationBuffer, StructType aggregationBufferSchema, StructType groupingKeySchema, - TaskMemoryManager taskMemoryManager, + TaskContext taskContext, int initialCapacity, long pageSizeBytes) { this.aggregationBufferSchema = aggregationBufferSchema; this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length()); this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; - this.map = - new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, true); + this.map = new BytesToBytesMap( + taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, true); // Initialize the buffer for aggregation value final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the aggregation map's output (e.g. aggregate followed by limit). + taskContext.addTaskCompletionListener(context -> { + free(); + }); } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 12f4d658b1868..9bfad1e83ee7b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -136,7 +136,7 @@ public int getInt(int rowId) { public long getLong(int rowId) { int index = getRowIndex(rowId); if (isTimestamp) { - return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000; + return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000 % 1000; } else { return longData.vector[index]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index dcebdc39f0aa2..a0d9578a377b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -497,7 +497,7 @@ private void putValues( * Returns the number of micros since epoch from an element of TimestampColumnVector. */ private static long fromTimestampColumnVector(TimestampColumnVector vector, int index) { - return vector.time[index] * 1000L + vector.nanos[index] / 1000L; + return vector.time[index] * 1000 + (vector.nanos[index] / 1000 % 1000); } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index e65cd252c3ddf..c975e52734e01 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.parquet; -import java.io.ByteArrayInputStream; import java.io.File; import java.io.IOException; import java.lang.reflect.InvocationTargetException; @@ -147,7 +146,8 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont this.sparkSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString); this.reader = new ParquetFileReader( configuration, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns()); - for (BlockMetaData block : blocks) { + // use the blocks from the reader in case some do not match filters and will not be read + for (BlockMetaData block : reader.getRowGroups()) { this.totalRowCount += block.getRowCount(); } @@ -225,7 +225,8 @@ protected void initialize(String path, List columns) throws IOException this.sparkSchema = new ParquetToSparkSchemaConverter(config).convert(requestedSchema); this.reader = new ParquetFileReader( config, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns()); - for (BlockMetaData block : blocks) { + // use the blocks from the reader in case some do not match filters and will not be read + for (BlockMetaData block : reader.getRowGroups()) { this.totalRowCount += block.getRowCount(); } } @@ -293,7 +294,7 @@ protected static IntIterator createRLEIterator( return new RLEIntIterator( new RunLengthBitPackingHybridDecoder( BytesUtils.getWidthFromMaxInt(maxLevel), - new ByteArrayInputStream(bytes.toByteArray()))); + bytes.toInputStream())); } catch (IOException e) { throw new IOException("could not read levels in page for col " + descriptor, e); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 72f1d024b08ce..d5969b55eef96 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -21,6 +21,8 @@ import java.util.Arrays; import java.util.TimeZone; +import org.apache.parquet.bytes.ByteBufferInputStream; +import org.apache.parquet.bytes.BytesInput; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Dictionary; @@ -388,7 +390,8 @@ private void decodeDictionaryIds( * is guaranteed that num is smaller than the number of values left in the current page. */ - private void readBooleanBatch(int rowId, int num, WritableColumnVector column) { + private void readBooleanBatch(int rowId, int num, WritableColumnVector column) + throws IOException { if (column.dataType() != DataTypes.BooleanType) { throw constructConvertNotSupportedException(descriptor, column); } @@ -396,7 +399,7 @@ private void readBooleanBatch(int rowId, int num, WritableColumnVector column) { num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } - private void readIntBatch(int rowId, int num, WritableColumnVector column) { + private void readIntBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType || @@ -414,7 +417,7 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) { } } - private void readLongBatch(int rowId, int num, WritableColumnVector column) { + private void readLongBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. if (column.dataType() == DataTypes.LongType || DecimalType.is64BitDecimalType(column.dataType()) || @@ -434,7 +437,7 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) { } } - private void readFloatBatch(int rowId, int num, WritableColumnVector column) { + private void readFloatBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: support implicit cast to double? if (column.dataType() == DataTypes.FloatType) { @@ -445,7 +448,7 @@ private void readFloatBatch(int rowId, int num, WritableColumnVector column) { } } - private void readDoubleBatch(int rowId, int num, WritableColumnVector column) { + private void readDoubleBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions if (column.dataType() == DataTypes.DoubleType) { @@ -456,7 +459,7 @@ private void readDoubleBatch(int rowId, int num, WritableColumnVector column) { } } - private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { + private void readBinaryBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; @@ -556,7 +559,7 @@ public Void visit(DataPageV2 dataPageV2) { }); } - private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) throws IOException { + private void initDataReader(Encoding dataEncoding, ByteBufferInputStream in) throws IOException { this.endOfPageValueCount = valuesRead + pageValueCount; if (dataEncoding.usesDictionary()) { this.dataColumn = null; @@ -581,7 +584,7 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) thr } try { - dataColumn.initFromPage(pageValueCount, bytes, offset); + dataColumn.initFromPage(pageValueCount, in); } catch (IOException e) { throw new IOException("could not read page in col " + descriptor, e); } @@ -602,12 +605,11 @@ private void readPageV1(DataPageV1 page) throws IOException { this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); try { - byte[] bytes = page.getBytes().toByteArray(); - rlReader.initFromPage(pageValueCount, bytes, 0); - int next = rlReader.getNextOffset(); - dlReader.initFromPage(pageValueCount, bytes, next); - next = dlReader.getNextOffset(); - initDataReader(page.getValueEncoding(), bytes, next); + BytesInput bytes = page.getBytes(); + ByteBufferInputStream in = bytes.toInputStream(); + rlReader.initFromPage(pageValueCount, in); + dlReader.initFromPage(pageValueCount, in); + initDataReader(page.getValueEncoding(), in); } catch (IOException e) { throw new IOException("could not read page " + page + " in col " + descriptor, e); } @@ -619,12 +621,13 @@ private void readPageV2(DataPageV2 page) throws IOException { page.getRepetitionLevels(), descriptor); int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); - this.defColumn = new VectorizedRleValuesReader(bitWidth); + // do not read the length from the stream. v2 pages handle dividing the page bytes. + this.defColumn = new VectorizedRleValuesReader(bitWidth, false); this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn); - this.defColumn.initFromBuffer( - this.pageValueCount, page.getDefinitionLevels().toByteArray()); + this.defColumn.initFromPage( + this.pageValueCount, page.getDefinitionLevels().toInputStream()); try { - initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0); + initDataReader(page.getDataEncoding(), page.getData().toInputStream()); } catch (IOException e) { throw new IOException("could not read page " + page + " in col " + descriptor, e); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 5b75f719339fb..c62dc3d86386e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -20,8 +20,9 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; +import org.apache.parquet.bytes.ByteBufferInputStream; +import org.apache.parquet.io.ParquetDecodingException; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; -import org.apache.spark.unsafe.Platform; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; @@ -30,24 +31,18 @@ * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. */ public class VectorizedPlainValuesReader extends ValuesReader implements VectorizedValuesReader { - private byte[] buffer; - private int offset; - private int bitOffset; // Only used for booleans. - private ByteBuffer byteBuffer; // used to wrap the byte array buffer + private ByteBufferInputStream in = null; - private static final boolean bigEndianPlatform = - ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + // Only used for booleans. + private int bitOffset; + private byte currentByte = 0; public VectorizedPlainValuesReader() { } @Override - public void initFromPage(int valueCount, byte[] bytes, int offset) throws IOException { - this.buffer = bytes; - this.offset = offset + Platform.BYTE_ARRAY_OFFSET; - if (bigEndianPlatform) { - byteBuffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN); - } + public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException { + this.in = in; } @Override @@ -63,115 +58,157 @@ public final void readBooleans(int total, WritableColumnVector c, int rowId) { } } + private ByteBuffer getBuffer(int length) { + try { + return in.slice(length).order(ByteOrder.LITTLE_ENDIAN); + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read " + length + " bytes", e); + } + } + @Override public final void readIntegers(int total, WritableColumnVector c, int rowId) { - c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 4 * total; + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putIntsLittleEndian(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putInt(rowId + i, buffer.getInt()); + } + } } @Override public final void readLongs(int total, WritableColumnVector c, int rowId) { - c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 8 * total; + int requiredBytes = total * 8; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putLongsLittleEndian(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putLong(rowId + i, buffer.getLong()); + } + } } @Override public final void readFloats(int total, WritableColumnVector c, int rowId) { - c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 4 * total; + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putFloats(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putFloat(rowId + i, buffer.getFloat()); + } + } } @Override public final void readDoubles(int total, WritableColumnVector c, int rowId) { - c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 8 * total; + int requiredBytes = total * 8; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putDoubles(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putDouble(rowId + i, buffer.getDouble()); + } + } } @Override public final void readBytes(int total, WritableColumnVector c, int rowId) { - for (int i = 0; i < total; i++) { - // Bytes are stored as a 4-byte little endian int. Just read the first byte. - // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. - c.putByte(rowId + i, Platform.getByte(buffer, offset)); - offset += 4; + // Bytes are stored as a 4-byte little endian int. Just read the first byte. + // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + for (int i = 0; i < total; i += 1) { + c.putByte(rowId + i, buffer.get()); + // skip the next 3 bytes + buffer.position(buffer.position() + 3); } } @Override public final boolean readBoolean() { - byte b = Platform.getByte(buffer, offset); - boolean v = (b & (1 << bitOffset)) != 0; + // TODO: vectorize decoding and keep boolean[] instead of currentByte + if (bitOffset == 0) { + try { + currentByte = (byte) in.read(); + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read a byte", e); + } + } + + boolean v = (currentByte & (1 << bitOffset)) != 0; bitOffset += 1; if (bitOffset == 8) { bitOffset = 0; - offset++; } return v; } @Override public final int readInteger() { - int v = Platform.getInt(buffer, offset); - if (bigEndianPlatform) { - v = java.lang.Integer.reverseBytes(v); - } - offset += 4; - return v; + return getBuffer(4).getInt(); } @Override public final long readLong() { - long v = Platform.getLong(buffer, offset); - if (bigEndianPlatform) { - v = java.lang.Long.reverseBytes(v); - } - offset += 8; - return v; + return getBuffer(8).getLong(); } @Override public final byte readByte() { - return (byte)readInteger(); + return (byte) readInteger(); } @Override public final float readFloat() { - float v; - if (!bigEndianPlatform) { - v = Platform.getFloat(buffer, offset); - } else { - v = byteBuffer.getFloat(offset - Platform.BYTE_ARRAY_OFFSET); - } - offset += 4; - return v; + return getBuffer(4).getFloat(); } @Override public final double readDouble() { - double v; - if (!bigEndianPlatform) { - v = Platform.getDouble(buffer, offset); - } else { - v = byteBuffer.getDouble(offset - Platform.BYTE_ARRAY_OFFSET); - } - offset += 8; - return v; + return getBuffer(8).getDouble(); } @Override public final void readBinary(int total, WritableColumnVector v, int rowId) { for (int i = 0; i < total; i++) { int len = readInteger(); - int start = offset; - offset += len; - v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len); + ByteBuffer buffer = getBuffer(len); + if (buffer.hasArray()) { + v.putByteArray(rowId + i, buffer.array(), buffer.arrayOffset() + buffer.position(), len); + } else { + byte[] bytes = new byte[len]; + buffer.get(bytes); + v.putByteArray(rowId + i, bytes); + } } } @Override public final Binary readBinary(int len) { - Binary result = Binary.fromConstantByteArray(buffer, offset - Platform.BYTE_ARRAY_OFFSET, len); - offset += len; - return result; + ByteBuffer buffer = getBuffer(len); + if (buffer.hasArray()) { + return Binary.fromConstantByteArray( + buffer.array(), buffer.arrayOffset() + buffer.position(), len); + } else { + byte[] bytes = new byte[len]; + buffer.get(bytes); + return Binary.fromConstantByteArray(bytes); + } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index fc7fa70c39419..fe3d31ae8e746 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet; import org.apache.parquet.Preconditions; +import org.apache.parquet.bytes.ByteBufferInputStream; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.column.values.bitpacking.BytePacker; @@ -27,6 +28,9 @@ import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import java.io.IOException; +import java.nio.ByteBuffer; + /** * A values reader for Parquet's run-length encoded data. This is based off of the version in * parquet-mr with these changes: @@ -49,9 +53,7 @@ private enum MODE { } // Encoded data. - private byte[] in; - private int end; - private int offset; + private ByteBufferInputStream in; // bit/byte width of decoded data and utility to batch unpack them. private int bitWidth; @@ -70,45 +72,40 @@ private enum MODE { // If true, the bit width is fixed. This decoder is used in different places and this also // controls if we need to read the bitwidth from the beginning of the data stream. private final boolean fixedWidth; + private final boolean readLength; public VectorizedRleValuesReader() { - fixedWidth = false; + this.fixedWidth = false; + this.readLength = false; } public VectorizedRleValuesReader(int bitWidth) { - fixedWidth = true; + this.fixedWidth = true; + this.readLength = bitWidth != 0; + init(bitWidth); + } + + public VectorizedRleValuesReader(int bitWidth, boolean readLength) { + this.fixedWidth = true; + this.readLength = readLength; init(bitWidth); } @Override - public void initFromPage(int valueCount, byte[] page, int start) { - this.offset = start; - this.in = page; + public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException { + this.in = in; if (fixedWidth) { - if (bitWidth != 0) { + // initialize for repetition and definition levels + if (readLength) { int length = readIntLittleEndian(); - this.end = this.offset + length; + this.in = in.sliceStream(length); } } else { - this.end = page.length; - if (this.end != this.offset) init(page[this.offset++] & 255); - } - if (bitWidth == 0) { - // 0 bit width, treat this as an RLE run of valueCount number of 0's. - this.mode = MODE.RLE; - this.currentCount = valueCount; - this.currentValue = 0; - } else { - this.currentCount = 0; + // initialize for values + if (in.available() > 0) { + init(in.read()); + } } - } - - // Initialize the reader from a buffer. This is used for the V2 page encoding where the - // definition are in its own buffer. - public void initFromBuffer(int valueCount, byte[] data) { - this.offset = 0; - this.in = data; - this.end = data.length; if (bitWidth == 0) { // 0 bit width, treat this as an RLE run of valueCount number of 0's. this.mode = MODE.RLE; @@ -129,11 +126,6 @@ private void init(int bitWidth) { this.packer = Packer.LITTLE_ENDIAN.newBytePacker(bitWidth); } - @Override - public int getNextOffset() { - return this.end; - } - @Override public boolean readBoolean() { return this.readInteger() != 0; @@ -182,7 +174,7 @@ public void readIntegers( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -217,7 +209,7 @@ public void readBooleans( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -251,7 +243,7 @@ public void readBytes( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -285,7 +277,7 @@ public void readShorts( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -321,7 +313,7 @@ public void readLongs( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -355,7 +347,7 @@ public void readFloats( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -389,7 +381,7 @@ public void readDoubles( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -423,7 +415,7 @@ public void readBinarys( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -462,7 +454,7 @@ public void readIntegers( WritableColumnVector nulls, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -559,12 +551,12 @@ public Binary readBinary(int len) { /** * Reads the next varint encoded int. */ - private int readUnsignedVarInt() { + private int readUnsignedVarInt() throws IOException { int value = 0; int shift = 0; int b; do { - b = in[offset++] & 255; + b = in.read(); value |= (b & 0x7F) << shift; shift += 7; } while ((b & 0x80) != 0); @@ -574,35 +566,32 @@ private int readUnsignedVarInt() { /** * Reads the next 4 byte little endian int. */ - private int readIntLittleEndian() { - int ch4 = in[offset] & 255; - int ch3 = in[offset + 1] & 255; - int ch2 = in[offset + 2] & 255; - int ch1 = in[offset + 3] & 255; - offset += 4; + private int readIntLittleEndian() throws IOException { + int ch4 = in.read(); + int ch3 = in.read(); + int ch2 = in.read(); + int ch1 = in.read(); return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0)); } /** * Reads the next byteWidth little endian int. */ - private int readIntLittleEndianPaddedOnBitWidth() { + private int readIntLittleEndianPaddedOnBitWidth() throws IOException { switch (bytesWidth) { case 0: return 0; case 1: - return in[offset++] & 255; + return in.read(); case 2: { - int ch2 = in[offset] & 255; - int ch1 = in[offset + 1] & 255; - offset += 2; + int ch2 = in.read(); + int ch1 = in.read(); return (ch1 << 8) + ch2; } case 3: { - int ch3 = in[offset] & 255; - int ch2 = in[offset + 1] & 255; - int ch1 = in[offset + 2] & 255; - offset += 3; + int ch3 = in.read(); + int ch2 = in.read(); + int ch1 = in.read(); return (ch1 << 16) + (ch2 << 8) + (ch3 << 0); } case 4: { @@ -619,32 +608,36 @@ private int ceil8(int value) { /** * Reads the next group. */ - private void readNextGroup() { - int header = readUnsignedVarInt(); - this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED; - switch (mode) { - case RLE: - this.currentCount = header >>> 1; - this.currentValue = readIntLittleEndianPaddedOnBitWidth(); - return; - case PACKED: - int numGroups = header >>> 1; - this.currentCount = numGroups * 8; - int bytesToRead = ceil8(this.currentCount * this.bitWidth); - - if (this.currentBuffer.length < this.currentCount) { - this.currentBuffer = new int[this.currentCount]; - } - currentBufferIdx = 0; - int valueIndex = 0; - for (int byteIndex = offset; valueIndex < this.currentCount; byteIndex += this.bitWidth) { - this.packer.unpack8Values(in, byteIndex, this.currentBuffer, valueIndex); - valueIndex += 8; - } - offset += bytesToRead; - return; - default: - throw new ParquetDecodingException("not a valid mode " + this.mode); + private void readNextGroup() { + try { + int header = readUnsignedVarInt(); + this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED; + switch (mode) { + case RLE: + this.currentCount = header >>> 1; + this.currentValue = readIntLittleEndianPaddedOnBitWidth(); + return; + case PACKED: + int numGroups = header >>> 1; + this.currentCount = numGroups * 8; + + if (this.currentBuffer.length < this.currentCount) { + this.currentBuffer = new int[this.currentCount]; + } + currentBufferIdx = 0; + int valueIndex = 0; + while (valueIndex < this.currentCount) { + // values are bit packed 8 at a time, so reading bitWidth will always work + ByteBuffer buffer = in.slice(bitWidth); + this.packer.unpack8Values(buffer, buffer.position(), this.currentBuffer, valueIndex); + valueIndex += 8; + } + return; + default: + throw new ParquetDecodingException("not a valid mode " + this.mode); + } + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read from input stream", e); } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 4733f36174f42..6fdadde628551 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -216,12 +216,12 @@ protected UTF8String getBytesAsUTF8String(int rowId, int count) { @Override public void putShort(int rowId, short value) { - Platform.putShort(null, data + 2 * rowId, value); + Platform.putShort(null, data + 2L * rowId, value); } @Override public void putShorts(int rowId, int count, short value) { - long offset = data + 2 * rowId; + long offset = data + 2L * rowId; for (int i = 0; i < count; ++i, offset += 2) { Platform.putShort(null, offset, value); } @@ -229,20 +229,20 @@ public void putShorts(int rowId, int count, short value) { @Override public void putShorts(int rowId, int count, short[] src, int srcIndex) { - Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2, - null, data + 2 * rowId, count * 2); + Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2L, + null, data + 2L * rowId, count * 2L); } @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 2, count * 2); + null, data + rowId * 2L, count * 2L); } @Override public short getShort(int rowId) { if (dictionary == null) { - return Platform.getShort(null, data + 2 * rowId); + return Platform.getShort(null, data + 2L * rowId); } else { return (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -252,7 +252,7 @@ public short getShort(int rowId) { public short[] getShorts(int rowId, int count) { assert(dictionary == null); short[] array = new short[count]; - Platform.copyMemory(null, data + rowId * 2, array, Platform.SHORT_ARRAY_OFFSET, count * 2); + Platform.copyMemory(null, data + rowId * 2L, array, Platform.SHORT_ARRAY_OFFSET, count * 2L); return array; } @@ -262,12 +262,12 @@ public short[] getShorts(int rowId, int count) { @Override public void putInt(int rowId, int value) { - Platform.putInt(null, data + 4 * rowId, value); + Platform.putInt(null, data + 4L * rowId, value); } @Override public void putInts(int rowId, int count, int value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putInt(null, offset, value); } @@ -275,24 +275,24 @@ public void putInts(int rowId, int count, int value) { @Override public void putInts(int rowId, int count, int[] src, int srcIndex) { - Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 4 * rowId, count * 4); + null, data + 4L * rowId, count * 4L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4, srcOffset += 4) { Platform.putInt(null, offset, java.lang.Integer.reverseBytes(Platform.getInt(src, srcOffset))); @@ -303,7 +303,7 @@ public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public int getInt(int rowId) { if (dictionary == null) { - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } else { return dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -313,7 +313,7 @@ public int getInt(int rowId) { public int[] getInts(int rowId, int count) { assert(dictionary == null); int[] array = new int[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.INT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.INT_ARRAY_OFFSET, count * 4L); return array; } @@ -325,7 +325,7 @@ public int[] getInts(int rowId, int count) { public int getDictId(int rowId) { assert(dictionary == null) : "A ColumnVector dictionary should not have a dictionary for itself."; - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } // @@ -334,12 +334,12 @@ public int getDictId(int rowId) { @Override public void putLong(int rowId, long value) { - Platform.putLong(null, data + 8 * rowId, value); + Platform.putLong(null, data + 8L * rowId, value); } @Override public void putLongs(int rowId, int count, long value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putLong(null, offset, value); } @@ -347,24 +347,24 @@ public void putLongs(int rowId, int count, long value) { @Override public void putLongs(int rowId, int count, long[] src, int srcIndex) { - Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 8 * rowId, count * 8); + null, data + 8L * rowId, count * 8L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8, srcOffset += 8) { Platform.putLong(null, offset, java.lang.Long.reverseBytes(Platform.getLong(src, srcOffset))); @@ -375,7 +375,7 @@ public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public long getLong(int rowId) { if (dictionary == null) { - return Platform.getLong(null, data + 8 * rowId); + return Platform.getLong(null, data + 8L * rowId); } else { return dictionary.decodeToLong(dictionaryIds.getDictId(rowId)); } @@ -385,7 +385,7 @@ public long getLong(int rowId) { public long[] getLongs(int rowId, int count) { assert(dictionary == null); long[] array = new long[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.LONG_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.LONG_ARRAY_OFFSET, count * 8L); return array; } @@ -395,12 +395,12 @@ public long[] getLongs(int rowId, int count) { @Override public void putFloat(int rowId, float value) { - Platform.putFloat(null, data + rowId * 4, value); + Platform.putFloat(null, data + rowId * 4L, value); } @Override public void putFloats(int rowId, int count, float value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, value); } @@ -408,18 +408,18 @@ public void putFloats(int rowId, int count, float value) { @Override public void putFloats(int rowId, int count, float[] src, int srcIndex) { - Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, bb.getFloat(srcIndex + (4 * i))); } @@ -429,7 +429,7 @@ public void putFloats(int rowId, int count, byte[] src, int srcIndex) { @Override public float getFloat(int rowId) { if (dictionary == null) { - return Platform.getFloat(null, data + rowId * 4); + return Platform.getFloat(null, data + rowId * 4L); } else { return dictionary.decodeToFloat(dictionaryIds.getDictId(rowId)); } @@ -439,7 +439,7 @@ public float getFloat(int rowId) { public float[] getFloats(int rowId, int count) { assert(dictionary == null); float[] array = new float[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.FLOAT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.FLOAT_ARRAY_OFFSET, count * 4L); return array; } @@ -450,12 +450,12 @@ public float[] getFloats(int rowId, int count) { @Override public void putDouble(int rowId, double value) { - Platform.putDouble(null, data + rowId * 8, value); + Platform.putDouble(null, data + rowId * 8L, value); } @Override public void putDoubles(int rowId, int count, double value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, value); } @@ -463,18 +463,18 @@ public void putDoubles(int rowId, int count, double value) { @Override public void putDoubles(int rowId, int count, double[] src, int srcIndex) { - Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, bb.getDouble(srcIndex + (8 * i))); } @@ -484,7 +484,7 @@ public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public double getDouble(int rowId) { if (dictionary == null) { - return Platform.getDouble(null, data + rowId * 8); + return Platform.getDouble(null, data + rowId * 8L); } else { return dictionary.decodeToDouble(dictionaryIds.getDictId(rowId)); } @@ -494,7 +494,7 @@ public double getDouble(int rowId) { public double[] getDoubles(int rowId, int count) { assert(dictionary == null); double[] array = new double[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8L); return array; } @@ -504,26 +504,26 @@ public double[] getDoubles(int rowId, int count) { @Override public void putArray(int rowId, int offset, int length) { assert(offset >= 0 && offset + length <= childColumns[0].capacity); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, offset); + Platform.putInt(null, lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, offset); } @Override public int getArrayLength(int rowId) { - return Platform.getInt(null, lengthData + 4 * rowId); + return Platform.getInt(null, lengthData + 4L * rowId); } @Override public int getArrayOffset(int rowId) { - return Platform.getInt(null, offsetData + 4 * rowId); + return Platform.getInt(null, offsetData + 4L * rowId); } // APIs dealing with ByteArrays @Override public int putByteArray(int rowId, byte[] value, int offset, int length) { int result = arrayData().appendBytes(length, value, offset); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, result); + Platform.putInt(null, lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, result); return result; } @@ -533,19 +533,19 @@ protected void reserveInternal(int newCapacity) { int oldCapacity = (nulls == 0L) ? 0 : capacity; if (isArray() || type instanceof MapType) { this.lengthData = - Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(lengthData, oldCapacity * 4L, newCapacity * 4L); this.offsetData = - Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(offsetData, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof ByteType || type instanceof BooleanType) { this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 2, newCapacity * 2); + this.data = Platform.reallocateMemory(data, oldCapacity * 2L, newCapacity * 2L); } else if (type instanceof IntegerType || type instanceof FloatType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { - this.data = Platform.reallocateMemory(data, oldCapacity * 4, newCapacity * 4); + this.data = Platform.reallocateMemory(data, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); + this.data = Platform.reallocateMemory(data, oldCapacity * 8L, newCapacity * 8L); } else if (childColumns != null) { // Nothing to store. } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 23dcc104e67c4..577eab6ed14c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -231,7 +231,7 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, shortData, - Platform.SHORT_ARRAY_OFFSET + rowId * 2, count * 2); + Platform.SHORT_ARRAY_OFFSET + rowId * 2L, count * 2L); } @Override @@ -276,7 +276,7 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, intData, - Platform.INT_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.INT_ARRAY_OFFSET + rowId * 4L, count * 4L); } @Override @@ -342,7 +342,7 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, longData, - Platform.LONG_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.LONG_ARRAY_OFFSET + rowId * 8L, count * 8L); } @Override @@ -394,7 +394,7 @@ public void putFloats(int rowId, int count, float[] src, int srcIndex) { public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, floatData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { @@ -443,7 +443,7 @@ public void putDoubles(int rowId, int count, double[] src, int srcIndex) { public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, doubleData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 5275e4a91eac0..b0e119d658cb4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -81,7 +81,9 @@ public void close() { } public void reserve(int requiredCapacity) { - if (requiredCapacity > capacity) { + if (requiredCapacity < 0) { + throwUnsupportedException(requiredCapacity, null); + } else if (requiredCapacity > capacity) { int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); if (requiredCapacity <= newCapacity) { try { @@ -96,13 +98,16 @@ public void reserve(int requiredCapacity) { } private void throwUnsupportedException(int requiredCapacity, Throwable cause) { - String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + - "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + - "vectorized reader, or increase the vectorized reader batch size. For parquet file " + - "format, refer to " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + " and " + - SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + "; for orc file format, refer to " + - SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + " and " + - SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + "."; + String message = "Cannot reserve additional contiguous bytes in the vectorized reader (" + + (requiredCapacity >= 0 ? "requested " + requiredCapacity + " bytes" : "integer overflow") + + "). As a workaround, you can reduce the vectorized reader batch size, or disable the " + + "vectorized reader. For parquet file format, refer to " + + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + + " (default " + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + + ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for orc file format, " + + "refer to " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + + " (default " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + + ") and " + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + "."; throw new RuntimeException(message, cause); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java index 209ffa7a0b9fa..7f4a2c9593c76 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java @@ -34,7 +34,7 @@ public interface MicroBatchReadSupport extends DataSourceV2 { * streaming query. * * The execution engine will create a micro-batch reader at the start of a streaming query, - * alternate calls to setOffsetRange and createDataReaderFactories for each batch to process, and + * alternate calls to setOffsetRange and planInputPartitions for each batch to process, and * then call stop() when the execution is complete. Note that a single query may have multiple * executions due to restart or failure recovery. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index 0ea4dc6b5def3..b2526ded53d92 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -30,7 +30,7 @@ public interface ReadSupport extends DataSourceV2 { /** * Creates a {@link DataSourceReader} to scan the data from this data source. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param options the options for the returned data source reader, which is an immutable diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index 3801402268af1..f31659904cc53 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -35,7 +35,7 @@ public interface ReadSupportWithSchema extends DataSourceV2 { /** * Create a {@link DataSourceReader} to scan the data from this data source. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param schema the full schema of this data source reader. Full schema usually maps to the diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index cab56453816cc..83aeec0c47853 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -35,7 +35,7 @@ public interface WriteSupport extends DataSourceV2 { * Creates an optional {@link DataSourceWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done according to the save mode. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param jobId A unique string for the writing job. It's possible that there are many writing diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java similarity index 67% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java index a61697649c43e..dcb87715d0b6f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java @@ -21,15 +21,15 @@ import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset; /** - * A mix-in interface for {@link DataReaderFactory}. Continuous data reader factories can - * implement this interface to provide creating {@link DataReader} with particular offset. + * A mix-in interface for {@link InputPartition}. Continuous input partitions can + * implement this interface to provide creating {@link InputPartitionReader} with particular offset. */ @InterfaceStability.Evolving -public interface ContinuousDataReaderFactory extends DataReaderFactory { +public interface ContinuousInputPartition extends InputPartition { /** - * Create a DataReader with particular offset as its startOffset. + * Create an input partition reader with particular offset as its startOffset. * - * @param offset offset want to set as the DataReader's startOffset. + * @param offset offset want to set as the input partition reader's startOffset. */ - DataReader createDataReaderWithOffset(PartitionOffset offset); + InputPartitionReader createContinuousReader(PartitionOffset offset); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index a470bccc5aad2..36a3e542b5a11 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -31,8 +31,8 @@ * {@link ReadSupport#createReader(DataSourceOptions)} or * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic is delegated to {@link DataReaderFactory}s that are returned by - * {@link #createDataReaderFactories()}. + * logic is delegated to {@link InputPartition}s, which are returned by + * {@link #planInputPartitions()}. * * There are mainly 3 kinds of query optimizations: * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column @@ -45,8 +45,8 @@ * only one of them would be respected, according to the priority list from high to low: * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}. * - * If an exception was throw when applying any of these query optimizations, the action would fail - * and no Spark job was submitted. + * If an exception was throw when applying any of these query optimizations, the action will fail + * and no Spark job will be submitted. * * Spark first applies all operator push-down optimizations that this data source supports. Then * Spark collects information this data source reported for further optimizations. Finally Spark @@ -59,22 +59,22 @@ public interface DataSourceReader { * Returns the actual schema of this data source reader, which may be different from the physical * schema of the underlying storage, as column pruning or other optimizations may happen. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ StructType readSchema(); /** - * Returns a list of reader factories. Each factory is responsible for creating a data reader to - * output data for one RDD partition. That means the number of factories returned here is same as - * the number of RDD partitions this scan outputs. + * Returns a list of {@link InputPartition}s. Each {@link InputPartition} is responsible for + * creating a data reader to output data of one RDD partition. The number of input partitions + * returned here is the same as the number of RDD partitions this scan outputs. * * Note that, this may not be a full scan if the data source reader mixes in other optimization * interfaces like column pruning, filter push-down, etc. These optimizations are applied before * Spark issues the scan request. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ - List> createDataReaderFactories(); + List> planInputPartitions(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java similarity index 60% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index 32e98e8f5d8bd..f2038d0de3ffe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -22,29 +22,30 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is - * responsible for creating the actual data reader. The relationship between - * {@link DataReaderFactory} and {@link DataReader} + * An input partition returned by {@link DataSourceReader#planInputPartitions()} and is + * responsible for creating the actual data reader of one RDD partition. + * The relationship between {@link InputPartition} and {@link InputPartitionReader} * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * - * Note that, the reader factory will be serialized and sent to executors, then the data reader - * will be created on executors and do the actual reading. So {@link DataReaderFactory} must be - * serializable and {@link DataReader} doesn't need to be. + * Note that {@link InputPartition}s will be serialized and sent to executors, then + * {@link InputPartitionReader}s will be created on executors to do the actual reading. So + * {@link InputPartition} must be serializable while {@link InputPartitionReader} doesn't need to + * be. */ @InterfaceStability.Evolving -public interface DataReaderFactory extends Serializable { +public interface InputPartition extends Serializable { /** - * The preferred locations where the data reader returned by this reader factory can run faster, - * but Spark does not guarantee to run the data reader on these locations. + * The preferred locations where the input partition reader returned by this partition can run + * faster, but Spark does not guarantee to run the input partition reader on these locations. * The implementations should make sure that it can be run on any location. * The location is a string representing the host name. * * Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in - * the returned locations. By default this method returns empty string array, which means this - * task has no location preference. + * the returned locations. The default return value is empty string array, which means this + * input partition's reader has no location preference. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ default String[] preferredLocations() { @@ -52,10 +53,10 @@ default String[] preferredLocations() { } /** - * Returns a data reader to do the actual reading work. + * Returns an input partition reader to do the actual reading work. * * If this method fails (by throwing an exception), the corresponding Spark task would fail and * get retried until hitting the maximum retry times. */ - DataReader createDataReader(); + InputPartitionReader createPartitionReader(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java similarity index 80% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java index bb9790a1c819e..33fa7be4c1b20 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java @@ -23,15 +23,15 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for - * outputting data for a RDD partition. + * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is + * responsible for outputting data for a RDD partition. * - * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data - * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source - * readers that mix in {@link SupportsScanUnsafeRow}. + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input + * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input + * partition readers that mix in {@link SupportsScanUnsafeRow}. */ @InterfaceStability.Evolving -public interface DataReader extends Closeable { +public interface InputPartitionReader extends Closeable { /** * Proceed to next record, returns false if there is no more records. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 290d614805ac7..4543c143a9aca 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -39,7 +39,16 @@ public interface SupportsPushDownCatalystFilters extends DataSourceReader { Expression[] pushCatalystFilters(Expression[] filters); /** - * Returns the catalyst filters that are pushed in {@link #pushCatalystFilters(Expression[])}. + * Returns the catalyst filters that are pushed to the data source via + * {@link #pushCatalystFilters(Expression[])}. + * + * There are 3 kinds of filters: + * 1. pushable filters which don't need to be evaluated again after scanning. + * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet + * row group filter. + * 3. non-pushable filters. + * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. + * * It's possible that there is no filters in the query and * {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for * this case. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 1cff024232a44..b6a90a3d0b681 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -37,7 +37,15 @@ public interface SupportsPushDownFilters extends DataSourceReader { Filter[] pushFilters(Filter[] filters); /** - * Returns the filters that are pushed in {@link #pushFilters(Filter[])}. + * Returns the filters that are pushed to the data source via {@link #pushFilters(Filter[])}. + * + * There are 3 kinds of filters: + * 1. pushable filters which don't need to be evaluated again after scanning. + * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet + * row group filter. + * 3. non-pushable filters. + * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. + * * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])} * is never called, empty array should be returned for this case. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index 607628746e873..6b60da7c4dc1d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -24,7 +24,7 @@ * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report data partitioning and try to avoid shuffle at Spark side. * - * Note that, when the reader creates exactly one {@link DataReaderFactory}, Spark may avoid + * Note that, when the reader creates exactly one {@link InputPartition}, Spark may avoid * adding a shuffle even if the reader does not implement this interface. */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index 11bb13fd3b211..926396414816c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -22,6 +22,10 @@ /** * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report statistics to Spark. + * + * Statistics are reported to the optimizer before any operator is pushed to the DataSourceReader. + * Implementations that return more accurate statistics based on pushed operators will not improve + * query performance until the planner can push operators before getting stats. */ @InterfaceStability.Evolving public interface SupportsReportStatistics extends DataSourceReader { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java index 2e5cfa78511f0..0faf81db24605 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -30,22 +30,22 @@ @InterfaceStability.Evolving public interface SupportsScanColumnarBatch extends DataSourceReader { @Override - default List> createDataReaderFactories() { + default List> planInputPartitions() { throw new IllegalStateException( - "createDataReaderFactories not supported by default within SupportsScanColumnarBatch."); + "planInputPartitions not supported by default within SupportsScanColumnarBatch."); } /** - * Similar to {@link DataSourceReader#createDataReaderFactories()}, but returns columnar data + * Similar to {@link DataSourceReader#planInputPartitions()}, but returns columnar data * in batches. */ - List> createBatchDataReaderFactories(); + List> planBatchInputPartitions(); /** * Returns true if the concrete data source reader can read data in batch according to the scan * properties like required columns, pushes filters, etc. It's possible that the implementation * can only support some certain columns with certain types. Users can overwrite this method and - * {@link #createDataReaderFactories()} to fallback to normal read path under some conditions. + * {@link #planInputPartitions()} to fallback to normal read path under some conditions. */ default boolean enableBatchRead() { return true; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index 9cd749e8e4ce9..f2220f6d31093 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -33,14 +33,14 @@ public interface SupportsScanUnsafeRow extends DataSourceReader { @Override - default List> createDataReaderFactories() { + default List> planInputPartitions() { throw new IllegalStateException( - "createDataReaderFactories not supported by default within SupportsScanUnsafeRow"); + "planInputPartitions not supported by default within SupportsScanUnsafeRow"); } /** - * Similar to {@link DataSourceReader#createDataReaderFactories()}, + * Similar to {@link DataSourceReader#planInputPartitions()}, * but returns data in unsafe row format. */ - List> createUnsafeRowReaderFactories(); + List> planUnsafeInputPartitions(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 2d0ee50212b56..38ca5fc6387b2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -18,12 +18,12 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; /** * A concrete implementation of {@link Distribution}. Represents a distribution where records that * share the same values for the {@link #clusteredColumns} will be produced by the same - * {@link DataReader}. + * {@link InputPartitionReader}. */ @InterfaceStability.Evolving public class ClusteredDistribution implements Distribution { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index f6b111fdf220d..5e32ba6952e1c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -18,13 +18,14 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; /** * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the data partitions(one {@link DataReader} outputs data for one partition). + * be distributed among the data partitions (one {@link InputPartitionReader} outputs data for one + * partition). * Note that this interface has nothing to do with the data ordering inside one - * partition(the output records of a single {@link DataReader}). + * partition(the output records of a single {@link InputPartitionReader}). * * The instance of this interface is created and provided by Spark, then consumed by * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index 309d9e5de0a0f..f460f6bfe3bb9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; /** @@ -31,7 +31,7 @@ public interface Partitioning { /** - * Returns the number of partitions(i.e., {@link DataReaderFactory}s) the data source outputs. + * Returns the number of partitions(i.e., {@link InputPartition}s) the data source outputs. */ int numPartitions(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java similarity index 84% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java index 47d26440841fd..7b0ba0bbdda90 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java @@ -18,13 +18,13 @@ package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; /** - * A variation on {@link DataReader} for use with streaming in continuous processing mode. + * A variation on {@link InputPartitionReader} for use with streaming in continuous processing mode. */ @InterfaceStability.Evolving -public interface ContinuousDataReader extends DataReader { +public interface ContinuousInputPartitionReader extends InputPartitionReader { /** * Get the offset of the current record, or the start offset if no records have been read. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index 7fe7f00ac2fa8..6e960bedf8020 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -27,7 +27,7 @@ * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to allow reading in a continuous processing mode stream. * - * Implementations must ensure each reader factory output is a {@link ContinuousDataReader}. + * Implementations must ensure each partition reader is a {@link ContinuousInputPartitionReader}. * * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. @@ -35,8 +35,8 @@ @InterfaceStability.Evolving public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { /** - * Merge partitioned offsets coming from {@link ContinuousDataReader} instances for each - * partition to a single global offset. + * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances + * for each partition to a single global offset. */ Offset mergeOffsets(PartitionOffset[] offsets); @@ -47,7 +47,7 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceReader Offset deserializeOffset(String json); /** - * Set the desired start offset for reader factories created from this reader. The scan will + * Set the desired start offset for partitions created from this reader. The scan will * start from the first record after the provided offset, or from an implementation-defined * inferred starting point if no offset is provided. */ @@ -61,8 +61,8 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceReader Offset getStartOffset(); /** - * The execution engine will call this method in every epoch to determine if new reader - * factories need to be generated, which may be required if for example the underlying + * The execution engine will call this method in every epoch to determine if new input + * partitions need to be generated, which may be required if for example the underlying * source system has had partitions added or removed. * * If true, the query will be shut down and restarted with a new reader. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java index 67ebde30d61a9..0159c731762d9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java @@ -33,7 +33,7 @@ @InterfaceStability.Evolving public interface MicroBatchReader extends DataSourceReader, BaseStreamingSource { /** - * Set the desired offset range for reader factories created from this reader. Reader factories + * Set the desired offset range for input partitions created from this reader. Partition readers * will generate only data within (`start`, `end`]; that is, from the first record after `start` * to the record with offset `end`. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 0a0fd8db58035..7eedc85a5d6f3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -34,8 +34,8 @@ * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * - * If an exception was throw when applying any of these writing optimizations, the action would fail - * and no Spark job was submitted. + * If an exception was throw when applying any of these writing optimizations, the action will fail + * and no Spark job will be submitted. * * The writing procedure is: * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the @@ -58,14 +58,14 @@ public interface DataSourceWriter { /** * Creates a writer factory which will be serialized and sent to executors. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ DataWriterFactory createWriterFactory(); /** - * Returns whether Spark should use the commit coordinator to ensure that at most one attempt for - * each task commits. + * Returns whether Spark should use the commit coordinator to ensure that at most one task for + * each partition commits. * * @return true if commit coordinator should be used, false otherwise. */ @@ -90,9 +90,9 @@ default void onDataWriterCommit(WriterCommitMessage message) {} * is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it. * * Note that speculative execution may cause multiple tasks to run for a partition. By default, - * Spark uses the commit coordinator to allow at most one attempt to commit. Implementations can + * Spark uses the commit coordinator to allow at most one task to commit. Implementations can * disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple - * attempts may have committed successfully and one successful commit message per task will be + * tasks may have committed successfully and one successful commit message per task will be * passed to this commit method. The remaining commit messages are ignored by Spark. */ void commit(WriterCommitMessage[] messages); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 39bf458298862..1626c0013e4e7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is + * A data writer returned by {@link DataWriterFactory#createDataWriter(int, long, long)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -39,14 +39,14 @@ * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an * exception will be sent to the driver side, and Spark may retry this writing task a few times. - * In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a - * different `attemptNumber`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} + * In each retry, {@link DataWriterFactory#createDataWriter(int, long, long)} will receive a + * different `taskId`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} * when the configured number of retries is exhausted. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task * takes too long to finish. Different from retried tasks, which are launched one by one after the * previous one fails, speculative tasks are running simultaneously. It's possible that one input - * RDD partition has multiple data writers with different `attemptNumber` running at the same time, + * RDD partition has multiple data writers with different `taskId` running at the same time, * and data sources should guarantee that these data writers don't conflict and can work together. * Implementations can coordinate with driver during {@link #commit()} to make sure only one of * these data writers can commit successfully. Or implementations can allow all of them to commit diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index c2c2ab73257e8..0932ff8f8f8a7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -35,22 +35,19 @@ public interface DataWriterFactory extends Serializable { /** * Returns a data writer to do the actual writing work. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param partitionId A unique id of the RDD partition that the returned writer will process. * Usually Spark processes many RDD partitions at the same time, * implementations should use the partition id to distinguish writers for * different partitions. - * @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task - * failed, Spark launches a new task wth the same task id but different - * attempt number. Or a task is too slow, Spark launches new tasks wth the - * same task id but different attempt number, which means there are multiple - * tasks with the same task id running at the same time. Implementations can - * use this attempt number to distinguish writers of different task attempts. + * @param taskId A unique identifier for a task that is performing the write of the partition + * data. Spark may run multiple tasks for the same partition (due to speculation + * or task failures, for example). * @param epochId A monotonically increasing id for streaming queries that are split in to * discrete periods of execution. For non-streaming queries, * this ID will always be 0. */ - DataWriter createDataWriter(int partitionId, int attemptNumber, long epochId); + DataWriter createDataWriter(int partitionId, long taskId, long epochId); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ad0efbae89830..4eee3de5f7d4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.InterfaceStability @@ -103,7 +104,7 @@ class TypedColumn[-T, U]( * * {{{ * df("columnName") // On a specific `df` DataFrame. - * col("columnName") // A generic column no yet associated with a DataFrame. + * col("columnName") // A generic column not yet associated with a DataFrame. * col("columnName.field") // Extracting a struct field * col("`a.column.with.dots`") // Escape `.` in column names. * $"columnName" // Scala short hand for a named column. @@ -780,12 +781,54 @@ class Column(val expr: Expression) extends Logging { * A boolean expression that is evaluated to true if the value of this expression is contained * by the evaluated values of the arguments. * + * Note: Since the type of the elements in the list are inferred only during the run time, + * the elements will be "up-casted" to the most common type for comparison. + * For eg: + * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the + * comparison will look like "String vs String". + * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the + * comparison will look like "Double vs Double" + * * @group expr_ops * @since 1.5.0 */ @scala.annotation.varargs def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the provided collection. + * + * Note: Since the type of the elements in the collection are inferred only during the run time, + * the elements will be "up-casted" to the most common type for comparison. + * For eg: + * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the + * comparison will look like "String vs String". + * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the + * comparison will look like "Double vs Double" + * + * @group expr_ops + * @since 2.4.0 + */ + def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*) + + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the provided collection. + * + * Note: Since the type of the elements in the collection are inferred only during the run time, + * the elements will be "up-casted" to the most common type for comparison. + * For eg: + * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the + * comparison will look like "String vs String". + * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the + * comparison will look like "Double vs Double" + * + * @group java_expr_ops + * @since 2.4.0 + */ + def isInCollection(values: java.lang.Iterable[_]): Column = isInCollection(values.asScala) + /** * SQL like expression. Returns a boolean column based on a SQL LIKE match. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index d640fdc530ce2..ec9352a7fa055 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -22,6 +22,7 @@ import java.util.{Locale, Properties} import scala.collection.JavaConverters._ import com.fasterxml.jackson.databind.ObjectMapper +import com.univocity.parsers.csv.CsvParser import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability @@ -257,7 +258,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. "fetchsize" can be used to control the - * number of rows per fetch. + * number of rows per fetch and "queryTimeout" can be used to wait + * for a Statement object to execute to the given number of seconds. * @since 1.4.0 */ def jdbc( @@ -372,8 +374,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type. *
    • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
    • + *
    • `encoding` (by default it is not set): allows to forcibly set one of standard basic + * or extended encoding for the JSON files. For example UTF-16BE, UTF-32LE. If the encoding + * is not specified and `multiLine` is set to `true`, it will be detected automatically.
    • *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator * that should be used for parsing.
    • + *
    • `samplingRatio` (default is 1.0): defines fraction of input JSON objects used + * for schema inferring.
    • + *
    • `dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or + * empty array/struct during schema inference.
    • * * * @since 2.0.0 @@ -468,12 +477,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * it determines the columns as string types and it reads only the first line to determine the * names and the number of fields. * + * If the enforceSchema is set to `false`, only the CSV header in the first line is checked + * to conform specified or inferred schema. + * * @param csvDataset input Dataset with one CSV row per record * @since 2.2.0 */ def csv(csvDataset: Dataset[String]): DataFrame = { val parsedOptions: CSVOptions = new CSVOptions( extraOptions.toMap, + sparkSession.sessionState.conf.csvColumnPruning, sparkSession.sessionState.conf.sessionLocalTimeZone) val filteredLines: Dataset[String] = CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions) @@ -492,6 +505,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => + CSVDataSource.checkHeader( + firstLine, + new CsvParser(parsedOptions.asParserSettings), + actualSchema, + csvDataset.getClass.getCanonicalName, + parsedOptions.enforceSchema, + sparkSession.sessionState.conf.caseSensitiveAnalysis) filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) }.getOrElse(filteredLines.rdd) @@ -532,8 +552,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
    • `comment` (default empty string): sets a single character used for skipping lines * beginning with this character. By default, it is disabled.
    • *
    • `header` (default `false`): uses the first line as names of columns.
    • + *
    • `enforceSchema` (default `true`): If it is set to `true`, the specified or inferred schema + * will be forcibly applied to datasource files, and headers in CSV files will be ignored. + * If the option is set to `false`, the schema will be validated against all headers in CSV files + * in the case when the `header` option is set to `true`. Field names in the schema + * and column names in CSV headers are checked by their positions taking into account + * `spark.sql.caseSensitive`. Though the default value is true, it is recommended to disable + * the `enforceSchema` option to avoid incorrect results.
    • *
    • `inferSchema` (default `false`): infers the input schema automatically from data. It * requires one extra pass over the data.
    • + *
    • `samplingRatio` (default is 1.0): defines fraction of rows used for schema inferring.
    • *
    • `ignoreLeadingWhiteSpace` (default `false`): a flag indicating whether or not leading * whitespaces from values being read should be skipped.
    • *
    • `ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing @@ -575,6 +603,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
    • *
    • `multiLine` (default `false`): parse one record, which may span multiple lines.
    • * + * * @since 2.0.0 */ @scala.annotation.varargs diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index bbc063148a72c..90bea2d676e22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -330,8 +330,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } private def getBucketSpec: Option[BucketSpec] = { - if (sortColumnNames.isDefined) { - require(numBuckets.isDefined, "sortBy must be used together with bucketBy") + if (sortColumnNames.isDefined && numBuckets.isEmpty) { + throw new AnalysisException("sortBy must be used together with bucketBy") } numBuckets.map { n => @@ -340,8 +340,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } private def assertNotBucketed(operation: String): Unit = { - if (numBuckets.isDefined || sortColumnNames.isDefined) { - throw new AnalysisException(s"'$operation' does not support bucketing right now") + if (getBucketSpec.isDefined) { + if (sortColumnNames.isEmpty) { + throw new AnalysisException(s"'$operation' does not support bucketBy right now") + } else { + throw new AnalysisException(s"'$operation' does not support bucketBy and sortBy right now") + } } } @@ -518,8 +522,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
    • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
    • - *
    • `lineSep` (default `\n`): defines the line separator that should - * be used for writing.
    • + *
    • `encoding` (by default it is not set): specifies encoding (charset) of saved json + * files. If it is not set, the UTF-8 charset will be used.
    • + *
    • `lineSep` (default `\n`): defines the line separator that should be used for writing.
    • * * * @since 1.4.0 @@ -589,8 +594,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
    • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
    • - *
    • `lineSep` (default `\n`): defines the line separator that should - * be used for writing.
    • + *
    • `lineSep` (default `\n`): defines the line separator that should be used for writing.
    • * * * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 917168162b236..c97246f30220d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -196,7 +196,7 @@ class Dataset[T] private[sql]( } // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. - @transient private val planWithBarrier = AnalysisBarrier(logicalPlan) + @transient private[sql] val planWithBarrier = AnalysisBarrier(logicalPlan) /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the @@ -231,16 +231,15 @@ class Dataset[T] private[sql]( } /** - * Compose the string representing rows for output + * Get rows represented in Sequence by specific truncate and vertical requirement. * - * @param _numRows Number of rows to show + * @param numRows Number of rows to return * @param truncate If set to more than 0, truncates strings to `truncate` characters and * all cells will be aligned right. - * @param vertical If set to true, prints output rows vertically (one line per column value). */ - private[sql] def showString( - _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { - val numRows = _numRows.max(0).min(Int.MaxValue - 1) + private[sql] def getRows( + numRows: Int, + truncate: Int): Seq[Seq[String]] = { val newDf = toDF() val castCols = newDf.logicalPlan.output.map { col => // Since binary types in top-level schema fields have a specific format to print, @@ -251,14 +250,12 @@ class Dataset[T] private[sql]( Column(col).cast(StringType) } } - val takeResult = newDf.select(castCols: _*).take(numRows + 1) - val hasMoreData = takeResult.length > numRows - val data = takeResult.take(numRows) + val data = newDf.select(castCols: _*).take(numRows + 1) // For array values, replace Seq and Array with square brackets // For cells that are beyond `truncate` characters, replace it with the // first `truncate-3` and "..." - val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row => + schema.fieldNames.toSeq +: data.map { row => row.toSeq.map { cell => val str = cell match { case null => "null" @@ -274,6 +271,26 @@ class Dataset[T] private[sql]( } }: Seq[String] } + } + + /** + * Compose the string representing rows for output + * + * @param _numRows Number of rows to show + * @param truncate If set to more than 0, truncates strings to `truncate` characters and + * all cells will be aligned right. + * @param vertical If set to true, prints output rows vertically (one line per column value). + */ + private[sql] def showString( + _numRows: Int, + truncate: Int = 20, + vertical: Boolean = false): String = { + val numRows = _numRows.max(0).min(Int.MaxValue - 1) + // Get rows represented by Seq[Seq[String]], we may get one more line if it has more data. + val tmpRows = getRows(numRows, truncate) + + val hasMoreData = tmpRows.length - 1 > numRows + val rows = tmpRows.take(numRows + 1) val sb = new StringBuilder val numCols = schema.fieldNames.length @@ -291,31 +308,25 @@ class Dataset[T] private[sql]( } } + val paddedRows = rows.map { row => + row.zipWithIndex.map { case (cell, i) => + if (truncate > 0) { + StringUtils.leftPad(cell, colWidths(i)) + } else { + StringUtils.rightPad(cell, colWidths(i)) + } + } + } + // Create SeparateLine val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() // column names - rows.head.zipWithIndex.map { case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell, colWidths(i)) - } else { - StringUtils.rightPad(cell, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - + paddedRows.head.addString(sb, "|", "|", "|\n") sb.append(sep) // data - rows.tail.foreach { - _.zipWithIndex.map { case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell.toString, colWidths(i)) - } else { - StringUtils.rightPad(cell.toString, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - } - + paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n")) sb.append(sep) } else { // Extended display mode enabled @@ -346,7 +357,7 @@ class Dataset[T] private[sql]( } // Print a footer - if (vertical && data.isEmpty) { + if (vertical && rows.tail.isEmpty) { // In a vertical mode, print an empty row set explicitly sb.append("(0 rows)\n") } else if (hasMoreData) { @@ -511,6 +522,16 @@ class Dataset[T] private[sql]( */ def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] + /** + * Returns true if the `Dataset` is empty. + * + * @group basic + * @since 2.4.0 + */ + def isEmpty: Boolean = withAction("isEmpty", limit(1).groupBy().count().queryExecution) { plan => + plan.executeCollect().head.getLong(0) == 0 + } + /** * Returns true if this Dataset contains one or more sources that continuously * return data as it arrives. A Dataset that reads data from a streaming source @@ -995,6 +1016,11 @@ class Dataset[T] private[sql]( catalyst.expressions.EqualTo( withPlan(plan.left).resolve(a.name), withPlan(plan.right).resolve(b.name)) + case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference) + if a.sameRef(b) => + catalyst.expressions.EqualNullSafe( + withPlan(plan.left).resolve(a.name), + withPlan(plan.right).resolve(b.name)) }} withPlan { @@ -1607,7 +1633,9 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def reduce(func: (T, T) => T): T = rdd.reduce(func) + def reduce(func: (T, T) => T): T = withNewRDDExecutionId { + rdd.reduce(func) + } /** * :: Experimental :: @@ -2933,12 +2961,13 @@ class Dataset[T] private[sql]( */ def storageLevel: StorageLevel = { sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData => - cachedData.cachedRepresentation.storageLevel + cachedData.cachedRepresentation.cacheBuilder.storageLevel }.getOrElse(StorageLevel.NONE) } /** * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. + * This will not un-persist any cached data that is built upon this Dataset. * * @param blocking Whether to block until all blocks are deleted. * @@ -2946,12 +2975,13 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def unpersist(blocking: Boolean): this.type = { - sparkSession.sharedState.cacheManager.uncacheQuery(this, blocking) + sparkSession.sharedState.cacheManager.uncacheQuery(this, cascade = false, blocking) this } /** * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. + * This will not un-persist any cached data that is built upon this Dataset. * * @group basic * @since 1.6.0 @@ -3187,7 +3217,7 @@ class Dataset[T] private[sql]( EvaluatePython.javaToPython(rdd) } - private[sql] def collectToPython(): Int = { + private[sql] def collectToPython(): Array[Any] = { EvaluatePython.registerPicklers() withAction("collectToPython", queryExecution) { plan => val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) @@ -3197,10 +3227,22 @@ class Dataset[T] private[sql]( } } + private[sql] def getRowsToPython( + _numRows: Int, + truncate: Int): Array[Any] = { + EvaluatePython.registerPicklers() + val numRows = _numRows.max(0).min(Int.MaxValue - 1) + val rows = getRows(numRows, truncate).map(_.toArray).toArray + val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType))) + val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( + rows.iterator.map(toJava)) + PythonRDD.serveIterator(iter, "serve-GetRows") + } + /** * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. */ - private[sql] def collectAsArrowToPython(): Int = { + private[sql] def collectAsArrowToPython(): Array[Any] = { withAction("collectAsArrowToPython", queryExecution) { plan => val iter: Iterator[Array[Byte]] = toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) @@ -3208,7 +3250,7 @@ class Dataset[T] private[sql]( } } - private[sql] def toPythonIterator(): Int = { + private[sql] def toPythonIterator(): Array[Any] = { withNewExecutionId { PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index 86e02e98c01f3..b21c50af18433 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -20,10 +20,48 @@ package org.apache.spark.sql import org.apache.spark.annotation.InterfaceStability /** - * A class to consume data generated by a `StreamingQuery`. Typically this is used to send the - * generated data to external systems. Each partition will use a new deserialized instance, so you - * usually should do all the initialization (e.g. opening a connection or initiating a transaction) - * in the `open` method. + * The abstract class for writing custom logic to process data generated by a query. + * This is often used to write the output of a streaming query to arbitrary storage systems. + * Any implementation of this base class will be used by Spark in the following way. + * + *
        + *
      • A single instance of this class is responsible of all the data generated by a single task + * in a query. In other words, one instance is responsible for processing one partition of the + * data generated in a distributed manner. + * + *
      • Any implementation of this class must be serializable because each task will get a fresh + * serialized-deserialized copy of the provided object. Hence, it is strongly recommended that + * any initialization for writing data (e.g. opening a connection or starting a transaction) + * is done after the `open(...)` method has been called, which signifies that the task is + * ready to generate data. + * + *
      • The lifecycle of the methods are as follows. + * + *
        + *   For each partition with `partitionId`:
        + *       For each batch/epoch of streaming data (if its streaming query) with `epochId`:
        + *           Method `open(partitionId, epochId)` is called.
        + *           If `open` returns true:
        + *                For each row in the partition and batch/epoch, method `process(row)` is called.
        + *           Method `close(errorOrNull)` is called with error (if any) seen while processing rows.
        + *   
        + * + *
      + * + * Important points to note: + *
        + *
      • The `partitionId` and `epochId` can be used to deduplicate generated data when failures + * cause reprocessing of some input data. This depends on the execution mode of the query. If + * the streaming query is being executed in the micro-batch mode, then every partition + * represented by a unique tuple (partitionId, epochId) is guaranteed to have the same data. + * Hence, (partitionId, epochId) can be used to deduplicate and/or transactionally commit data + * and achieve exactly-once guarantees. However, if the streaming query is being executed in the + * continuous mode, then this guarantee does not hold and therefore should not be used for + * deduplication. + * + *
      • The `close()` method will be called if `open()` method returns successfully (irrespective + * of the return value), except if the JVM crashes in the middle. + *
      * * Scala example: * {{{ @@ -63,6 +101,7 @@ import org.apache.spark.annotation.InterfaceStability * } * }); * }}} + * * @since 2.0.0 */ @InterfaceStability.Evolving @@ -71,23 +110,18 @@ abstract class ForeachWriter[T] extends Serializable { // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API. /** - * Called when starting to process one partition of new data in the executor. The `version` is - * for data deduplication when there are failures. When recovering from a failure, some data may - * be generated multiple times but they will always have the same version. - * - * If this method finds using the `partitionId` and `version` that this partition has already been - * processed, it can return `false` to skip the further data processing. However, `close` still - * will be called for cleaning up resources. + * Called when starting to process one partition of new data in the executor. See the class + * docs for more information on how to use the `partitionId` and `epochId`. * * @param partitionId the partition id. - * @param version a unique id for data deduplication. + * @param epochId a unique id for data deduplication. * @return `true` if the corresponding partition and version id should be processed. `false` * indicates the partition should be skipped. */ - def open(partitionId: Long, version: Long): Boolean + def open(partitionId: Long, epochId: Long): Boolean /** - * Called to process the data in the executor side. This method will be called only when `open` + * Called to process the data in the executor side. This method will be called only if `open` * returns `true`. */ def process(value: T): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 6bab21dca0cbd..36f6038aa9485 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -49,7 +49,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( private implicit val kExprEnc = encoderFor(kEncoder) private implicit val vExprEnc = encoderFor(vEncoder) - private def logicalPlan = queryExecution.analyzed + private def logicalPlan = AnalysisBarrier(queryExecution.analyzed) private def sparkSession = queryExecution.sparkSession /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 7147798d99533..c6449cd5a16b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -63,17 +63,17 @@ class RelationalGroupedDataset protected[sql]( groupType match { case RelationalGroupedDataset.GroupByType => Dataset.ofRows( - df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.planWithBarrier)) case RelationalGroupedDataset.RollupType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.planWithBarrier)) case RelationalGroupedDataset.CubeType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.planWithBarrier)) case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( - df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) + df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.planWithBarrier)) } } @@ -433,7 +433,7 @@ class RelationalGroupedDataset protected[sql]( df.exprEnc.schema, groupingAttributes, df.logicalPlan.output, - df.logicalPlan)) + df.planWithBarrier)) } /** @@ -459,7 +459,7 @@ class RelationalGroupedDataset protected[sql]( case other => Alias(other, other.toString)() } val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) - val child = df.logicalPlan + val child = df.planWithBarrier val project = Project(groupingNamedExpressions ++ child.output, child) val output = expr.dataType.asInstanceOf[StructType].toAttributes val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index b352e332bc7e0..3c39579149fff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -132,6 +132,17 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { sqlConf.unsetConf(key) } + /** + * Indicates whether the configuration property with the given key + * is modifiable in the current session. + * + * @return `true` if the configuration property is modifiable. For static SQL, Spark Core, + * invalid (not existing) and other non-modifiable configuration properties, + * the returned value is `false`. + * @since 2.4.0 + */ + def isModifiable(key: String): Boolean = sqlConf.isModifiable(key) + /** * Returns whether a particular key is set. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c502e583a55c5..565042fcf762e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,6 +898,7 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { + assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1020,16 +1021,34 @@ object SparkSession extends Logging { /** * Returns the active SparkSession for the current thread, returned by the builder. * + * @note Return None, when calling this function on executors + * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + def getActiveSession: Option[SparkSession] = { + if (TaskContext.get != null) { + // Return None when running on executors. + None + } else { + Option(activeThreadSession.get) + } + } /** * Returns the default SparkSession that is returned by the builder. * + * @note Return None, when calling this function on executors + * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + def getDefaultSession: Option[SparkSession] = { + if (TaskContext.get != null) { + // Return None when running on executors. + None + } else { + Option(defaultSession.get) + } + } /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1062,6 +1081,14 @@ object SparkSession extends Logging { } } + private def assertOnDriver(): Unit = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException( + "SparkSession should only be created and accessed on the driver.") + } + } + /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index a8794be7280c7..39d9a95ca4710 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ResolvedHint} +import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, LogicalPlan, ResolvedHint} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.storage.StorageLevel @@ -71,7 +71,7 @@ class CacheManager extends Logging { /** Clears all cached tables. */ def clearCache(): Unit = writeLock { - cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) + cachedData.asScala.foreach(_.cachedRepresentation.cacheBuilder.clearCache()) cachedData.clear() } @@ -97,7 +97,7 @@ class CacheManager extends Logging { val inMemoryRelation = InMemoryRelation( sparkSession.sessionState.conf.useCompression, sparkSession.sessionState.conf.columnBatchSize, storageLevel, - sparkSession.sessionState.executePlan(planToCache).executedPlan, + sparkSession.sessionState.executePlan(AnalysisBarrier(planToCache)).executedPlan, tableName, planToCache) cachedData.add(CachedData(planToCache, inMemoryRelation)) @@ -105,24 +105,50 @@ class CacheManager extends Logging { } /** - * Un-cache all the cache entries that refer to the given plan. + * Un-cache the given plan or all the cache entries that refer to the given plan. + * @param query The [[Dataset]] to be un-cached. + * @param cascade If true, un-cache all the cache entries that refer to the given + * [[Dataset]]; otherwise un-cache the given [[Dataset]] only. + * @param blocking Whether to block until all blocks are deleted. */ - def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock { - uncacheQuery(query.sparkSession, query.logicalPlan, blocking) + def uncacheQuery( + query: Dataset[_], + cascade: Boolean, + blocking: Boolean = true): Unit = writeLock { + uncacheQuery(query.sparkSession, query.logicalPlan, cascade, blocking) } /** - * Un-cache all the cache entries that refer to the given plan. + * Un-cache the given plan or all the cache entries that refer to the given plan. + * @param spark The Spark session. + * @param plan The plan to be un-cached. + * @param cascade If true, un-cache all the cache entries that refer to the given + * plan; otherwise un-cache the given plan only. + * @param blocking Whether to block until all blocks are deleted. */ - def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock { + def uncacheQuery( + spark: SparkSession, + plan: LogicalPlan, + cascade: Boolean, + blocking: Boolean): Unit = writeLock { + val shouldRemove: LogicalPlan => Boolean = + if (cascade) { + _.find(_.sameResult(plan)).isDefined + } else { + _.sameResult(plan) + } val it = cachedData.iterator() while (it.hasNext) { val cd = it.next() - if (cd.plan.find(_.sameResult(plan)).isDefined) { - cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + if (shouldRemove(cd.plan)) { + cd.cachedRepresentation.cacheBuilder.clearCache(blocking) it.remove() } } + // Re-compile dependent cached queries after removing the cached query. + if (!cascade) { + recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined, clearCache = false) + } } /** @@ -132,22 +158,24 @@ class CacheManager extends Logging { recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined) } - private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = { + private def recacheByCondition( + spark: SparkSession, + condition: LogicalPlan => Boolean, + clearCache: Boolean = true): Unit = { val it = cachedData.iterator() val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData] while (it.hasNext) { val cd = it.next() if (condition(cd.plan)) { - cd.cachedRepresentation.cachedColumnBuffers.unpersist() + if (clearCache) { + cd.cachedRepresentation.cacheBuilder.clearCache() + } // Remove the cache entry before we create a new one, so that we can have a different // physical plan. it.remove() + val plan = spark.sessionState.executePlan(AnalysisBarrier(cd.plan)).executedPlan val newCache = InMemoryRelation( - useCompression = cd.cachedRepresentation.useCompression, - batchSize = cd.cachedRepresentation.batchSize, - storageLevel = cd.cachedRepresentation.storageLevel, - child = spark.sessionState.executePlan(cd.plan).executedPlan, - tableName = cd.cachedRepresentation.tableName, + cacheBuilder = cd.cachedRepresentation.cacheBuilder.withCachedPlan(plan), logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index fc3dbc1c5591b..48abad9078650 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -58,14 +59,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" - val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { - s""" + val code = code"${ctx.registerComment(str)}" + (if (nullable) { + code""" boolean $isNullVar = $columnVar.isNullAt($ordinal); $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value); """ } else { - s"$javaType $valueVar = $value;" - }).trim + code"$javaType $valueVar = $value;" + }) ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 08ff33afbba3d..d7f2654be0451 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.sources.{BaseRelation, Filter} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils +import org.apache.spark.util.collection.BitSet trait DataSourceScanExec extends LeafExecNode with CodegenSupport { val relation: BaseRelation @@ -69,7 +70,7 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { * Shorthand for calling redactString() without specifying redacting rules */ private def redact(text: String): String = { - Utils.redact(sqlContext.sessionState.conf.stringRedationPattern, text) + Utils.redact(sqlContext.sessionState.conf.stringRedactionPattern, text) } } @@ -151,6 +152,7 @@ case class RowDataSourceScanExec( * @param output Output attributes of the scan, including data attributes and partition attributes. * @param requiredSchema Required schema of the underlying relation, excluding partition columns. * @param partitionFilters Predicates to use for partition pruning. + * @param optionalBucketSet Bucket ids for bucket pruning * @param dataFilters Filters on non-partition columns. * @param tableIdentifier identifier for the table in the metastore. */ @@ -159,6 +161,7 @@ case class FileSourceScanExec( output: Seq[Attribute], requiredSchema: StructType, partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], dataFilters: Seq[Expression], override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { @@ -286,7 +289,20 @@ case class FileSourceScanExec( } getOrElse { metadata } - withOptPartitionCount + + val withSelectedBucketsCount = relation.bucketSpec.map { spec => + val numSelectedBuckets = optionalBucketSet.map { b => + b.cardinality() + } getOrElse { + spec.numBuckets + } + withOptPartitionCount + ("SelectedBucketsCount" -> + s"$numSelectedBuckets out of ${spec.numBuckets}") + } getOrElse { + withOptPartitionCount + } + + withSelectedBucketsCount } private lazy val inputRDD: RDD[InternalRow] = { @@ -365,7 +381,7 @@ case class FileSourceScanExec( selectedPartitions: Seq[PartitionDirectory], fsRelation: HadoopFsRelation): RDD[InternalRow] = { logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") - val bucketed = + val filesGroupedToBuckets = selectedPartitions.flatMap { p => p.files.map { f => val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen) @@ -377,8 +393,17 @@ case class FileSourceScanExec( .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) } + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get + filesGroupedToBuckets.filter { + f => bucketSet.get(f._1) + } + } else { + filesGroupedToBuckets + } + val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId => - FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil)) + FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Nil)) } new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions) @@ -503,6 +528,7 @@ case class FileSourceScanExec( output.map(QueryPlan.normalizeExprId(_, output)), requiredSchema, QueryPlan.normalizePredicates(partitionFilters, output), + optionalBucketSet, QueryPlan.normalizePredicates(dataFilters, output), None) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index e4812f3d338fb..5b4edf5136e3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -152,7 +153,7 @@ case class ExpandExec( } else { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val code = s""" + val code = code""" |boolean $isNull = true; |${CodeGenerator.javaType(firstExpr.dataType)} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index f40c50df74ccb..2549b9e1537a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types._ @@ -313,13 +314,13 @@ case class GenerateExec( if (checks.nonEmpty) { val isNull = ctx.freshName("isNull") val code = - s""" + code""" |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt)) } else { - ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) + ExprCode(code"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 514ad7018d8c7..448eb703eacde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -25,6 +25,9 @@ import org.apache.spark.sql.execution.metric.SQLMetrics /** * Physical plan node for scanning data from a local collection. + * + * `Seq` may not be serializable and ideally we should not send `rows` and `unsafeRows` + * to the executors. Thus marking them as transient. */ case class LocalTableScanExec( output: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index acbd4becb8549..3ca03ab2939aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} @@ -49,9 +50,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic } plan.transform { - case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(_, attrs, filters, rel)) => + case a @ Aggregate(_, aggExprs, child @ PhysicalOperation( + projectList, filters, PartitionedRelation(partAttrs, rel))) => // We only apply this optimization when only partitioned attributes are scanned. - if (a.references.subsetOf(attrs)) { + if (AttributeSet((projectList ++ filters).flatMap(_.references)).subsetOf(partAttrs)) { + // The project list and filters all only refer to partition attributes, which means the + // the Aggregator operator can also only refer to partition attributes, and filters are + // all partition filters. This is a metadata only query we can optimize. val aggFunctions = aggExprs.flatMap(_.collect { case agg: AggregateExpression => agg }) @@ -102,7 +107,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic partFilters: Seq[Expression]): LogicalPlan = { // this logic comes from PruneFileSourcePartitions. it ensures that the filter names match the // relation's schema. PartitionedRelation ensures that the filters only reference partition cols - val relFilters = partFilters.map { e => + val normalizedFilters = partFilters.map { e => e transform { case a: AttributeReference => a.withName(relation.output.find(_.semanticEquals(a)).get.name) @@ -114,11 +119,8 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic relation match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - val partitionData = fsRelation.location.listFiles(relFilters, Nil) - // partition data may be a stream, which can cause serialization to hit stack level too - // deep exceptions because it is a recursive structure in memory. converting to array - // avoids the problem. - LocalRelation(partAttrs, partitionData.map(_.values).toArray, isStreaming) + val partitionData = fsRelation.location.listFiles(normalizedFilters, Nil) + LocalRelation(partAttrs, partitionData.map(_.values), isStreaming) case relation: HiveTableRelation => val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) @@ -127,7 +129,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(SQLConf.get.sessionLocalTimeZone) val partitions = if (partFilters.nonEmpty) { - catalog.listPartitionsByFilter(relation.tableMeta.identifier, relFilters) + catalog.listPartitionsByFilter(relation.tableMeta.identifier, normalizedFilters) } else { catalog.listPartitions(relation.tableMeta.identifier) } @@ -137,10 +139,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() }) } - // partition data may be a stream, which can cause serialization to hit stack level too - // deep exceptions because it is a recursive structure in memory. converting to array - // avoids the problem. - LocalRelation(partAttrs, partitionData.toArray) + LocalRelation(partAttrs, partitionData) case _ => throw new IllegalStateException(s"unrecognized table scan node: $relation, " + @@ -151,44 +150,21 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic /** * A pattern that finds the partitioned table relation node inside the given plan, and returns a - * pair of the partition attributes, partition filters, and the table relation node. - * - * It keeps traversing down the given plan tree if there is a [[Project]] or [[Filter]] with - * deterministic expressions, and returns result after reaching the partitioned table relation - * node. + * pair of the partition attributes and the table relation node. */ object PartitionedRelation extends PredicateHelper { - def unapply( - plan: LogicalPlan): Option[(AttributeSet, AttributeSet, Seq[Expression], LogicalPlan)] = { + def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = { plan match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) - if fsRelation.partitionSchema.nonEmpty => + if fsRelation.partitionSchema.nonEmpty => val partAttrs = AttributeSet(getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)) - Some((partAttrs, partAttrs, Nil, l)) + Some((partAttrs, l)) case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => val partAttrs = AttributeSet( getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)) - Some((partAttrs, partAttrs, Nil, relation)) - - case p @ Project(projectList, child) if projectList.forall(_.deterministic) => - unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => - if (p.references.subsetOf(attrs)) { - Some((partAttrs, p.outputSet, filters, relation)) - } else { - None - } - } - - case f @ Filter(condition, child) if condition.deterministic => - unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => - if (f.references.subsetOf(partAttrs)) { - Some((partAttrs, attrs, splitConjunctivePredicates(condition) ++ filters, relation)) - } else { - None - } - } + Some((partAttrs, relation)) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 15379a0663f7d..3112b306c365e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -225,7 +225,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { * Redact the sensitive information in the given string. */ private def withRedaction(message: String): String = { - Utils.redact(sparkSession.sessionState.conf.stringRedationPattern, message) + Utils.redact(sparkSession.sessionState.conf.stringRedactionPattern, message) } /** A special namespace for commands that can be used to debug query execution. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index e991da7df0bde..439932b0cc3ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -68,16 +68,18 @@ object SQLExecution { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" - val callSite = sparkSession.sparkContext.getCallSite() + val callSite = sc.getCallSite() - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + withSQLConfPropagated(sparkSession) { + sc.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + try { + body + } finally { + sc.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + } } } finally { executionIdToQueryExecution.remove(executionId) @@ -88,15 +90,43 @@ object SQLExecution { /** * Wrap an action with a known executionId. When running a different action in a different * thread from the original one, this method can be used to connect the Spark jobs in this action - * with the known executionId, e.g., `BroadcastHashJoin.broadcastFuture`. + * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ - def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { + def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = { + val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + withSQLConfPropagated(sparkSession) { + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + } + } + } + + /** + * Wrap an action with specified SQL configs. These configs will be propagated to the executor + * side via job local properties. + */ + def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { + val sc = sparkSession.sparkContext + // Set all the specified SQL configs to local properties, so that they can be available at + // the executor side. + val allConfigs = sparkSession.sessionState.conf.getAllConfs + val originalLocalProps = allConfigs.collect { + case (key, value) if key.startsWith("spark") => + val originalValue = sc.getLocalProperty(key) + sc.setLocalProperty(key, value) + (key, originalValue) + } + try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) body } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + for ((key, value) <- originalLocalProps) { + sc.setLocalProperty(key, value) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 1c8e4050978dc..00ff4c8ac310b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions -import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate class SparkOptimizer( @@ -32,8 +31,7 @@ class SparkOptimizer( override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ - Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 74048871f8d42..75f5ec0e253df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -41,6 +41,7 @@ class SparkPlanner( DataSourceStrategy(conf) :: SpecialLimits :: Aggregation :: + Window :: JoinSelection :: InMemoryScans :: BasicOperators :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 82b4eb9fba242..02e095b42a506 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -34,7 +35,7 @@ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2 import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.StreamingQuery +import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery} import org.apache.spark.sql.types.StructType /** @@ -66,22 +67,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object SpecialLimits extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(rootPlan) => rootPlan match { - case Limit(IntegerLiteral(limit), Sort(order, true, child)) => + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) => + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case Limit(IntegerLiteral(limit), child) => - // With whole stage codegen, Spark releases resources only when all the output data of the - // query plan are consumed. It's possible that `CollectLimitExec` only consumes a little - // data from child plan and finishes the query without releasing resources. Here we wrap - // the child plan with `LocalLimitExec`, to stop the processing of whole stage codegen and - // trigger the resource releasing work, after we consume `limit` rows. - CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + CollectLimitExec(limit, planLater(child)) :: Nil case other => planLater(other) :: Nil } - case Limit(IntegerLiteral(limit), Sort(order, true, child)) => + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) => + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case _ => Nil } @@ -323,7 +323,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => - if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) { + if (aggregateExpressions.exists(PythonUDF.isGroupedAggPandasUDF)) { throw new AnalysisException( "Streaming aggregation doesn't support group aggregate pandas UDF") } @@ -350,6 +350,29 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Used to plan the streaming global limit operator for streams in append mode. + * We need to check for either a direct Limit or a Limit wrapped in a ReturnAnswer operator, + * following the example of the SpecialLimits Strategy above. + * Streams with limit in Append mode use the stateful StreamingGlobalLimitExec. + * Streams with limit in Complete mode use the stateless CollectLimitExec operator. + * Limit is unsupported for streams in Update mode. + */ + case class StreamingGlobalLimitStrategy(outputMode: OutputMode) extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ReturnAnswer(rootPlan) => rootPlan match { + case Limit(IntegerLiteral(limit), child) + if plan.isStreaming && outputMode == InternalOutputModes.Append => + StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + case _ => Nil + } + case Limit(IntegerLiteral(limit), child) + if plan.isStreaming && outputMode == InternalOutputModes.Append => + StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + case _ => Nil + } + } + object StreamingJoinStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = { plan match { @@ -361,7 +384,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case Join(left, right, _, _) if left.isStreaming && right.isStreaming => throw new AnalysisException( - "Stream stream joins without equality predicate is not supported", plan = Some(plan)) + "Stream-stream join without equality predicate is not supported", plan = Some(plan)) case _ => Nil } @@ -380,9 +403,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + if (functionsWithDistinct.map(_.aggregateFunction.children.toSet).distinct.length > 1) { // This is a sanity check. We should not reach here when we have multiple distinct - // column sets. Our MultipleDistinctRewriter should take care this case. + // column sets. Our `RewriteDistinctAggregates` should take care this case. sys.error("You hit a query analyzer bug. Please report your query to " + "Spark user mailing list.") } @@ -424,6 +447,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object Window extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalWindow( + WindowFunctionType.SQL, windowExprs, partitionSpec, orderSpec, child) => + execution.window.WindowExec( + windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + + case PhysicalWindow( + WindowFunctionType.Python, windowExprs, partitionSpec, orderSpec, child) => + execution.python.WindowInPandasExec( + windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + + case _ => Nil + } + } + protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object InMemoryScans extends Strategy { @@ -474,7 +513,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil @@ -544,8 +582,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil case e @ logical.Expand(_, _, child) => execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil - case logical.Window(windowExprs, partitionSpec, orderSpec, child) => - execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 828b51fa199de..372dc3db36ce6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -122,10 +123,10 @@ trait CodegenSupport extends SparkPlan { ctx.INPUT_ROW = row ctx.currentVars = colVars val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - val code = s""" + val code = code""" |$evaluateInputs - |${ev.code.trim} - """.stripMargin.trim + |${ev.code} + """.stripMargin ExprCode(code, FalseLiteral, ev.value) } else { // There are no columns @@ -259,8 +260,8 @@ trait CodegenSupport extends SparkPlan { * them to be evaluated twice. */ protected def evaluateVariables(variables: Seq[ExprCode]): String = { - val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n") - variables.foreach(_.code = "") + val evaluate = variables.filter(_.code.nonEmpty).map(_.code.toString).mkString("\n") + variables.foreach(_.code = EmptyBlock) evaluate } @@ -275,8 +276,8 @@ trait CodegenSupport extends SparkPlan { val evaluateVars = new StringBuilder variables.zipWithIndex.foreach { case (ev, i) => if (ev.code != "" && required.contains(attributes(i))) { - evaluateVars.append(ev.code.trim + "\n") - ev.code = "" + evaluateVars.append(ev.code.toString + "\n") + ev.code = EmptyBlock } } evaluateVars.toString() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index a5dc6ebf2b0f2..2cac0cfce28de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -190,7 +191,7 @@ case class HashAggregateExec( val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") // The initial expression should not access any column val ev = e.genCode(ctx) - val initVars = s""" + val initVars = code""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin @@ -327,7 +328,7 @@ case class HashAggregateExec( initialBuffer, bufferSchema, groupingKeySchema, - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) @@ -755,7 +756,10 @@ case class HashAggregateExec( } // generate hash code for key - val hashExpr = Murmur3Hash(groupingExpressions, 42) + // SPARK-24076: HashAggregate uses the same hash algorithm on the same expressions + // as ShuffleExchange, it may lead to bad hash conflict when shuffle.partitions=8192*n, + // pick a different seed to avoid this conflict + val hashExpr = Murmur3Hash(groupingExpressions, 48) val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx) val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, @@ -770,8 +774,8 @@ case class HashAggregateExec( val findOrInsertRegularHashMap: String = s""" |// generate grouping key - |${unsafeRowKeyCode.code.trim} - |${hashEval.code.trim} + |${unsafeRowKeyCode.code} + |${hashEval.code} |if ($checkFallbackForBytesToBytesMap) { | // try to get the buffer from hash map | $unsafeRowBuffer = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index de2d630de3fdb..e1c85823259b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -50,7 +51,7 @@ abstract class HashMapGenerator( val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") val ev = e.genCode(ctx) val initVars = - s""" + code""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 9dc334c1ead3c..c1911235f8df3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -166,7 +166,7 @@ class TungstenAggregationIterator( initialAggregationBuffer, StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index aab8cc50b9526..6d44890704f49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils object TypedAggregateExpression { def apply[BUF : Encoder, OUT : Encoder]( @@ -109,7 +110,9 @@ trait TypedAggregateExpression extends AggregateFunction { s"$nodeName($input)" } - override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") + // aggregator.getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead + override def nodeName: String = Utils.getSimpleName(aggregator.getClass).stripSuffix("$"); } // TODO: merge these 2 implementations once we refactor the `AggregateFunction` interface. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 6ad11bda84bf6..93c8127681b3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -23,6 +23,7 @@ import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ object ArrowUtils { @@ -120,4 +121,19 @@ object ArrowUtils { StructField(field.getName, dt, field.isNullable) }) } + + /** Return Map with conf settings to be used in ArrowPythonRunner */ + def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = { + val timeZoneConf = if (conf.pandasRespectSessionTimeZone) { + Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone) + } else { + Nil + } + val pandasColsByPosition = if (conf.pandasGroupedMapAssignColumnssByPosition) { + Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION.key -> "true") + } else { + Nil + } + Map(timeZoneConf ++ pandasColsByPosition: _*) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 22b63513548fe..66888fce7f9f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -133,6 +133,14 @@ private[arrow] abstract class ArrowFieldWriter { valueVector match { case fixedWidthVector: BaseFixedWidthVector => fixedWidthVector.reset() case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset() + case listVector: ListVector => + // Manual "reset" the underlying buffer. + // TODO: When we upgrade to Arrow 0.10.0, we can simply remove this and call + // `listVector.reset()`. + val buffers = listVector.getBuffers(false) + buffers.foreach(buf => buf.setZero(0, buf.capacity())) + listVector.setValueCount(0) + listVector.setLastSet(0) case _ => } count = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 1edfdc888afd8..9434ceb7cd16c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -345,6 +345,20 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) override val output: Seq[Attribute] = range.output + override def outputOrdering: Seq[SortOrder] = range.outputOrdering + + override def outputPartitioning: Partitioning = { + if (numElements > 0) { + if (numSlices == 1) { + SinglePartition + } else { + RangePartitioning(outputOrdering, numSlices) + } + } else { + UnknownPartitioning(0) + } + } + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -629,7 +643,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val rows: Array[InternalRow] = child.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index a7ba9b86a176f..7c8faec53a828 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -32,19 +32,6 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator -object InMemoryRelation { - def apply( - useCompression: Boolean, - batchSize: Int, - storageLevel: StorageLevel, - child: SparkPlan, - tableName: Option[String], - logicalPlan: LogicalPlan): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)( - statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) -} - - /** * CachedBatch is a cached batch of rows. * @@ -55,58 +42,51 @@ object InMemoryRelation { private[columnar] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) -case class InMemoryRelation( - output: Seq[Attribute], +case class CachedRDDBuilder( useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, - @transient child: SparkPlan, + @transient cachedPlan: SparkPlan, tableName: Option[String])( - @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, - val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics, - override val outputOrdering: Seq[SortOrder]) - extends logical.LeafNode with MultiInstanceRelation { + @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null) { - override protected def innerChildren: Seq[SparkPlan] = Seq(child) + val sizeInBytesStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator - override def doCanonicalize(): logical.LogicalPlan = - copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)), - storageLevel = StorageLevel.NONE, - child = child.canonicalized, - tableName = None)( - _cachedColumnBuffers, - sizeInBytesStats, - statsOfPlanToCache, - outputOrdering) - - override def producedAttributes: AttributeSet = outputSet - - @transient val partitionStatistics = new PartitionStatistics(output) + def cachedColumnBuffers: RDD[CachedBatch] = { + if (_cachedColumnBuffers == null) { + synchronized { + if (_cachedColumnBuffers == null) { + _cachedColumnBuffers = buildBuffers() + } + } + } + _cachedColumnBuffers + } - override def computeStats(): Statistics = { - if (sizeInBytesStats.value == 0L) { - // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. - // Note that we should drop the hint info here. We may cache a plan whose root node is a hint - // node. When we lookup the cache with a semantically same plan without hint info, the plan - // returned by cache lookup should not have hint info. If we lookup the cache with a - // semantically same plan with a different hint info, `CacheManager.useCachedData` will take - // care of it and retain the hint info in the lookup input plan. - statsOfPlanToCache.copy(hints = HintInfo()) - } else { - Statistics(sizeInBytes = sizeInBytesStats.value.longValue) + def clearCache(blocking: Boolean = true): Unit = { + if (_cachedColumnBuffers != null) { + synchronized { + if (_cachedColumnBuffers != null) { + _cachedColumnBuffers.unpersist(blocking) + _cachedColumnBuffers = null + } + } } } - // If the cached column buffers were not passed in, we calculate them in the constructor. - // As in Spark, the actual work of caching is lazy. - if (_cachedColumnBuffers == null) { - buildBuffers() + def withCachedPlan(cachedPlan: SparkPlan): CachedRDDBuilder = { + new CachedRDDBuilder( + useCompression, + batchSize, + storageLevel, + cachedPlan = cachedPlan, + tableName + )(_cachedColumnBuffers) } - private def buildBuffers(): Unit = { - val output = child.output - val cached = child.execute().mapPartitionsInternal { rowIterator => + private def buildBuffers(): RDD[CachedBatch] = { + val output = cachedPlan.output + val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator => new Iterator[CachedBatch] { def next(): CachedBatch = { val columnBuilders = output.map { attribute => @@ -154,32 +134,77 @@ case class InMemoryRelation( cached.setName( tableName.map(n => s"In-memory table $n") - .getOrElse(StringUtils.abbreviate(child.toString, 1024))) - _cachedColumnBuffers = cached + .getOrElse(StringUtils.abbreviate(cachedPlan.toString, 1024))) + cached + } +} + +object InMemoryRelation { + + def apply( + useCompression: Boolean, + batchSize: Int, + storageLevel: StorageLevel, + child: SparkPlan, + tableName: Option[String], + logicalPlan: LogicalPlan): InMemoryRelation = { + val cacheBuilder = CachedRDDBuilder(useCompression, batchSize, storageLevel, child, tableName)() + new InMemoryRelation(child.output, cacheBuilder)( + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) + } + + def apply(cacheBuilder: CachedRDDBuilder, logicalPlan: LogicalPlan): InMemoryRelation = { + new InMemoryRelation(cacheBuilder.cachedPlan.output, cacheBuilder)( + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) + } +} + +case class InMemoryRelation( + output: Seq[Attribute], + @transient cacheBuilder: CachedRDDBuilder)( + statsOfPlanToCache: Statistics, + override val outputOrdering: Seq[SortOrder]) + extends logical.LeafNode with MultiInstanceRelation { + + override protected def innerChildren: Seq[SparkPlan] = Seq(cachedPlan) + + override def doCanonicalize(): logical.LogicalPlan = + copy(output = output.map(QueryPlan.normalizeExprId(_, cachedPlan.output)), + cacheBuilder)( + statsOfPlanToCache, + outputOrdering) + + override def producedAttributes: AttributeSet = outputSet + + @transient val partitionStatistics = new PartitionStatistics(output) + + def cachedPlan: SparkPlan = cacheBuilder.cachedPlan + + override def computeStats(): Statistics = { + if (cacheBuilder.sizeInBytesStats.value == 0L) { + // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. + // Note that we should drop the hint info here. We may cache a plan whose root node is a hint + // node. When we lookup the cache with a semantically same plan without hint info, the plan + // returned by cache lookup should not have hint info. If we lookup the cache with a + // semantically same plan with a different hint info, `CacheManager.useCachedData` will take + // care of it and retain the hint info in the lookup input plan. + statsOfPlanToCache.copy(hints = HintInfo()) + } else { + Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue) + } } def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { - InMemoryRelation( - newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering) + InMemoryRelation(newOutput, cacheBuilder)(statsOfPlanToCache, outputOrdering) } override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), - useCompression, - batchSize, - storageLevel, - child, - tableName)( - _cachedColumnBuffers, - sizeInBytesStats, + cacheBuilder)( statsOfPlanToCache, outputOrdering).asInstanceOf[this.type] } - def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers - - override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) + override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index e73e1378d52e3..997cf92449c68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ @@ -78,10 +78,12 @@ case class InMemoryTableScanExec( private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i))) - private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = { + private def createAndDecompressColumn( + cachedColumnarBatch: CachedBatch, + offHeapColumnVectorEnabled: Boolean): ColumnarBatch = { val rowCount = cachedColumnarBatch.numRows val taskContext = Option(TaskContext.get()) - val columnVectors = if (!conf.offHeapColumnVectorEnabled || taskContext.isEmpty) { + val columnVectors = if (!offHeapColumnVectorEnabled || taskContext.isEmpty) { OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) } else { OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) @@ -101,10 +103,13 @@ case class InMemoryTableScanExec( private lazy val inputRDD: RDD[InternalRow] = { val buffers = filteredCachedBatches() + val offHeapColumnVectorEnabled = conf.offHeapColumnVectorEnabled if (supportsBatch) { // HACK ALERT: This is actually an RDD[ColumnarBatch]. // We're taking advantage of Scala's type erasure here to pass these batches along. - buffers.map(createAndDecompressColumn).asInstanceOf[RDD[InternalRow]] + buffers + .map(createAndDecompressColumn(_, offHeapColumnVectorEnabled)) + .asInstanceOf[RDD[InternalRow]] } else { val numOutputRows = longMetric("numOutputRows") @@ -154,7 +159,7 @@ case class InMemoryTableScanExec( private def updateAttribute(expr: Expression): Expression = { // attributes can be pruned so using relation's output. // E.g., relation.output is [id, item] but this scan's output can be [item] only. - val attrMap = AttributeMap(relation.child.output.zip(relation.output)) + val attrMap = AttributeMap(relation.cachedPlan.output.zip(relation.output)) expr.transform { case attr: Attribute => attrMap.getOrElse(attr, attr) } @@ -163,16 +168,16 @@ case class InMemoryTableScanExec( // The cached version does not change the outputPartitioning of the original SparkPlan. // But the cached version could alias output, so we need to replace output. override def outputPartitioning: Partitioning = { - relation.child.outputPartitioning match { - case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] - case _ => relation.child.outputPartitioning + relation.cachedPlan.outputPartitioning match { + case e: Expression => updateAttribute(e).asInstanceOf[Partitioning] + case other => other } } // The cached version does not change the outputOrdering of the original SparkPlan. // But the cached version could alias output, so we need to replace output. override def outputOrdering: Seq[SortOrder] = - relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) + relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) // Keeps relation's partition statistics because we don't serialize relation. private val stats = relation.partitionStatistics @@ -252,7 +257,7 @@ case class InMemoryTableScanExec( // within the map Partitions closure. val schema = stats.schema val schemaIndex = schema.zipWithIndex - val buffers = relation.cachedColumnBuffers + val buffers = relation.cacheBuilder.cachedColumnBuffers buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => val partitionFilter = newPredicate( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index bf4d96fa18d0d..04bf8c6dd917f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -189,8 +189,9 @@ case class DropTableCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog + val isTempView = catalog.isTemporaryTable(tableName) - if (!catalog.isTemporaryTable(tableName) && catalog.tableExists(tableName)) { + if (!isTempView && catalog.tableExists(tableName)) { // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view // issue an exception. catalog.getTableMetadata(tableName).tableType match { @@ -204,9 +205,10 @@ case class DropTableCommand( } } - if (catalog.isTemporaryTable(tableName) || catalog.tableExists(tableName)) { + if (isTempView || catalog.tableExists(tableName)) { try { - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession.table(tableName), cascade = !isTempView) } catch { case NonFatal(e) => log.warn(e.toString, e) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 44749190c79eb..ec3961f84bd8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -493,7 +493,7 @@ case class TruncateTableCommand( spark.sessionState.refreshTable(tableName.unquotedString) // Also try to drop the contents of the table from the columnar cache try { - spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier)) + spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier), cascade = true) } catch { case NonFatal(e) => log.warn(s"Exception when attempting to uncache table $tableIdentWithDB", e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index 69c03d862391e..ba7d2b7cbdb1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.SerializableConfiguration /** - * Simple metrics collected during an instance of [[FileFormatWriter.ExecuteWriteTask]]. + * Simple metrics collected during an instance of [[FileFormatDataWriter]]. * These were first introduced in https://github.com/apache/spark/pull/18159 (SPARK-20703). */ case class BasicWriteTaskStats( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala index ea4fe9c8ade5f..a776fc3e7021d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning + object BucketingUtils { // The file name of bucketed data should have 3 parts: // 1. some other information in the head of file name @@ -35,5 +38,16 @@ object BucketingUtils { case other => None } + // Given bucketColumn, numBuckets and value, returns the corresponding bucketId + def getBucketIdFromValue(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { + val mutableInternalRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) + mutableInternalRow.update(0, value) + + val bucketIdGenerator = UnsafeProjection.create( + HashPartitioning(Seq(bucketColumn), numBuckets).partitionIdExpression :: Nil, + bucketColumn :: Nil) + bucketIdGenerator(mutableInternalRow).getInt(0) + } + def bucketIdToString(id: Int): String = f"_$id%05d" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index f16d824201e77..0c3d9a4895fe2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -396,6 +396,7 @@ case class DataSource( hs.partitionSchema.map(_.name), "in the partition schema", equality) + DataSourceUtils.verifyReadSchema(hs.fileFormat, hs.dataSchema) case _ => SchemaUtils.checkColumnNameDuplication( relation.schema.map(_.name), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 3f41612c08065..7b129435c45db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -312,18 +312,6 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with case _ => Nil } - // Get the bucket ID based on the bucketing values. - // Restriction: Bucket pruning works iff the bucketing column has one and only one column. - def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { - val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) - mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null) - val bucketIdGeneration = UnsafeProjection.create( - HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, - bucketColumn :: Nil) - - bucketIdGeneration(mutableRow).getInt(0) - } - // Based on Public API. private def pruneFilterProject( relation: LogicalRelation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala new file mode 100644 index 0000000000000..82e99190ecf14 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.types._ + + +object DataSourceUtils { + + /** + * Verify if the schema is supported in datasource in write path. + */ + def verifyWriteSchema(format: FileFormat, schema: StructType): Unit = { + verifySchema(format, schema, isReadPath = false) + } + + /** + * Verify if the schema is supported in datasource in read path. + */ + def verifyReadSchema(format: FileFormat, schema: StructType): Unit = { + verifySchema(format, schema, isReadPath = true) + } + + /** + * Verify if the schema is supported in datasource. This verification should be done + * in a driver side. + */ + private def verifySchema(format: FileFormat, schema: StructType, isReadPath: Boolean): Unit = { + schema.foreach { field => + if (!format.supportDataType(field.dataType, isReadPath)) { + throw new AnalysisException( + s"$format data source does not support ${field.dataType.simpleString} data type.") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 023e127888290..2c162e23644ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} /** @@ -57,7 +57,7 @@ trait FileFormat { dataSchema: StructType): OutputWriterFactory /** - * Returns whether this format support returning columnar batch or not. + * Returns whether this format supports returning columnar batch or not. * * TODO: we should just have different traits for the different formats. */ @@ -152,6 +152,11 @@ trait FileFormat { } } + /** + * Returns whether this format supports the given [[DataType]] in read/write path. + * By default all data types are supported. + */ + def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = true } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala new file mode 100644 index 0000000000000..6499328e89ce7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.TaskAttemptContext + +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.StringType +import org.apache.spark.util.SerializableConfiguration + +/** + * Abstract class for writing out data in a single Spark task. + * Exceptions thrown by the implementation of this trait will automatically trigger task aborts. + */ +abstract class FileFormatDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) { + /** + * Max number of files a single task writes out due to file size. In most cases the number of + * files written should be very small. This is just a safe guard to protect some really bad + * settings, e.g. maxRecordsPerFile = 1. + */ + protected val MAX_FILE_COUNTER: Int = 1000 * 1000 + protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]() + protected var currentWriter: OutputWriter = _ + + /** Trackers for computing various statistics on the data as it's being written out. */ + protected val statsTrackers: Seq[WriteTaskStatsTracker] = + description.statsTrackers.map(_.newTaskInstance()) + + protected def releaseResources(): Unit = { + if (currentWriter != null) { + try { + currentWriter.close() + } finally { + currentWriter = null + } + } + } + + /** Writes a record */ + def write(record: InternalRow): Unit + + /** + * Returns the summary of relative information which + * includes the list of partition strings written out. The list of partitions is sent back + * to the driver and used to update the catalog. Other information will be sent back to the + * driver too and used to e.g. update the metrics in UI. + */ + def commit(): WriteTaskResult = { + releaseResources() + val summary = ExecutedWriteSummary( + updatedPartitions = updatedPartitions.toSet, + stats = statsTrackers.map(_.getFinalStats())) + WriteTaskResult(committer.commitTask(taskAttemptContext), summary) + } + + def abort(): Unit = { + try { + releaseResources() + } finally { + committer.abortTask(taskAttemptContext) + } + } +} + +/** FileFormatWriteTask for empty partitions */ +class EmptyDirectoryDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol +) extends FileFormatDataWriter(description, taskAttemptContext, committer) { + override def write(record: InternalRow): Unit = {} +} + +/** Writes data to a single directory (used for non-dynamic-partition writes). */ +class SingleDirectoryDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) + extends FileFormatDataWriter(description, taskAttemptContext, committer) { + private var fileCounter: Int = _ + private var recordsInFile: Long = _ + // Initialize currentWriter and statsTrackers + newOutputWriter() + + private def newOutputWriter(): Unit = { + recordsInFile = 0 + releaseResources() + + val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) + val currentPath = committer.newTaskTempFile( + taskAttemptContext, + None, + f"-c$fileCounter%03d" + ext) + + currentWriter = description.outputWriterFactory.newInstance( + path = currentPath, + dataSchema = description.dataColumns.toStructType, + context = taskAttemptContext) + + statsTrackers.foreach(_.newFile(currentPath)) + } + + override def write(record: InternalRow): Unit = { + if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + + newOutputWriter() + } + + currentWriter.write(record) + statsTrackers.foreach(_.newRow(record)) + recordsInFile += 1 + } +} + +/** + * Writes data to using dynamic partition writes, meaning this single function can write to + * multiple directories (partitions) or files (bucketing). + */ +class DynamicPartitionDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) + extends FileFormatDataWriter(description, taskAttemptContext, committer) { + + /** Flag saying whether or not the data to be written out is partitioned. */ + private val isPartitioned = description.partitionColumns.nonEmpty + + /** Flag saying whether or not the data to be written out is bucketed. */ + private val isBucketed = description.bucketIdExpression.isDefined + + assert(isPartitioned || isBucketed, + s"""DynamicPartitionWriteTask should be used for writing out data that's either + |partitioned or bucketed. In this case neither is true. + |WriteJobDescription: $description + """.stripMargin) + + private var fileCounter: Int = _ + private var recordsInFile: Long = _ + private var currentPartionValues: Option[UnsafeRow] = None + private var currentBucketId: Option[Int] = None + + /** Extracts the partition values out of an input row. */ + private lazy val getPartitionValues: InternalRow => UnsafeRow = { + val proj = UnsafeProjection.create(description.partitionColumns, description.allColumns) + row => proj(row) + } + + /** Expression that given partition columns builds a path string like: col1=val/col2=val/... */ + private lazy val partitionPathExpression: Expression = Concat( + description.partitionColumns.zipWithIndex.flatMap { case (c, i) => + val partitionName = ScalaUDF( + ExternalCatalogUtils.getPartitionPathString _, + StringType, + Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId)))) + if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) + }) + + /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns + * the partition string. */ + private lazy val getPartitionPath: InternalRow => String = { + val proj = UnsafeProjection.create(Seq(partitionPathExpression), description.partitionColumns) + row => proj(row).getString(0) + } + + /** Given an input row, returns the corresponding `bucketId` */ + private lazy val getBucketId: InternalRow => Int = { + val proj = + UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns) + row => proj(row).getInt(0) + } + + /** Returns the data columns to be written given an input row */ + private val getOutputRow = + UnsafeProjection.create(description.dataColumns, description.allColumns) + + /** + * Opens a new OutputWriter given a partition key and/or a bucket id. + * If bucket id is specified, we will append it to the end of the file name, but before the + * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet + * + * @param partitionValues the partition which all tuples being written by this `OutputWriter` + * belong to + * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to + */ + private def newOutputWriter(partitionValues: Option[InternalRow], bucketId: Option[Int]): Unit = { + recordsInFile = 0 + releaseResources() + + val partDir = partitionValues.map(getPartitionPath(_)) + partDir.foreach(updatedPartitions.add) + + val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") + + // This must be in a form that matches our bucketing format. See BucketingUtils. + val ext = f"$bucketIdStr.c$fileCounter%03d" + + description.outputWriterFactory.getFileExtension(taskAttemptContext) + + val customPath = partDir.flatMap { dir => + description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) + } + val currentPath = if (customPath.isDefined) { + committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) + } else { + committer.newTaskTempFile(taskAttemptContext, partDir, ext) + } + + currentWriter = description.outputWriterFactory.newInstance( + path = currentPath, + dataSchema = description.dataColumns.toStructType, + context = taskAttemptContext) + + statsTrackers.foreach(_.newFile(currentPath)) + } + + override def write(record: InternalRow): Unit = { + val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None + val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None + + if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) { + // See a new partition or bucket - write to a new partition dir (or a new bucket file). + if (isPartitioned && currentPartionValues != nextPartitionValues) { + currentPartionValues = Some(nextPartitionValues.get.copy()) + statsTrackers.foreach(_.newPartition(currentPartionValues.get)) + } + if (isBucketed) { + currentBucketId = nextBucketId + statsTrackers.foreach(_.newBucket(currentBucketId.get)) + } + + fileCounter = 0 + newOutputWriter(currentPartionValues, currentBucketId) + } else if (description.maxRecordsPerFile > 0 && + recordsInFile >= description.maxRecordsPerFile) { + // Exceeded the threshold in terms of the number of records per file. + // Create a new file by increasing the file counter. + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + + newOutputWriter(currentPartionValues, currentBucketId) + } + val outputRow = getOutputRow(record) + currentWriter.write(outputRow) + statsTrackers.foreach(_.newRow(outputRow)) + recordsInFile += 1 + } +} + +/** A shared job description for all the write tasks. */ +class WriteJobDescription( + val uuid: String, // prevent collision between different (appending) write jobs + val serializableHadoopConf: SerializableConfiguration, + val outputWriterFactory: OutputWriterFactory, + val allColumns: Seq[Attribute], + val dataColumns: Seq[Attribute], + val partitionColumns: Seq[Attribute], + val bucketIdExpression: Option[Expression], + val path: String, + val customPartitionLocations: Map[TablePartitionSpec, String], + val maxRecordsPerFile: Long, + val timeZoneId: String, + val statsTrackers: Seq[WriteJobStatsTracker]) + extends Serializable { + + assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), + s""" + |All columns: ${allColumns.mkString(", ")} + |Partition columns: ${partitionColumns.mkString(", ")} + |Data columns: ${dataColumns.mkString(", ")} + """.stripMargin) +} + +/** The result of a successful write task. */ +case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary) + +/** + * Wrapper class for the metrics of writing data out. + * + * @param updatedPartitions the partitions updated during writing data out. Only valid + * for dynamic partition. + * @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had. + */ +case class ExecutedWriteSummary( + updatedPartitions: Set[String], + stats: Seq[WriteTaskStats]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 401597f967218..7c6ab4bc922fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources import java.util.{Date, UUID} -import scala.collection.mutable - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ @@ -30,62 +28,25 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} -import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils} +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} -import org.apache.spark.sql.types.StringType import org.apache.spark.util.{SerializableConfiguration, Utils} /** A helper object for writing FileFormat data out to a location. */ object FileFormatWriter extends Logging { - - /** - * Max number of files a single task writes out due to file size. In most cases the number of - * files written should be very small. This is just a safe guard to protect some really bad - * settings, e.g. maxRecordsPerFile = 1. - */ - private val MAX_FILE_COUNTER = 1000 * 1000 - /** Describes how output files should be placed in the filesystem. */ case class OutputSpec( - outputPath: String, - customPartitionLocations: Map[TablePartitionSpec, String], - outputColumns: Seq[Attribute]) - - /** A shared job description for all the write tasks. */ - private class WriteJobDescription( - val uuid: String, // prevent collision between different (appending) write jobs - val serializableHadoopConf: SerializableConfiguration, - val outputWriterFactory: OutputWriterFactory, - val allColumns: Seq[Attribute], - val dataColumns: Seq[Attribute], - val partitionColumns: Seq[Attribute], - val bucketIdExpression: Option[Expression], - val path: String, - val customPartitionLocations: Map[TablePartitionSpec, String], - val maxRecordsPerFile: Long, - val timeZoneId: String, - val statsTrackers: Seq[WriteJobStatsTracker]) - extends Serializable { - - assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), - s""" - |All columns: ${allColumns.mkString(", ")} - |Partition columns: ${partitionColumns.mkString(", ")} - |Data columns: ${dataColumns.mkString(", ")} - """.stripMargin) - } - - /** The result of a successful write task. */ - private case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary) + outputPath: String, + customPartitionLocations: Map[TablePartitionSpec, String], + outputColumns: Seq[Attribute]) /** * Basic work flow of this command is: @@ -135,9 +96,11 @@ object FileFormatWriter extends Logging { val caseInsensitiveOptions = CaseInsensitiveMap(options) + val dataSchema = dataColumns.toStructType + DataSourceUtils.verifyWriteSchema(fileFormat, dataSchema) // Note: prepareWrite has side effect. It sets "job". val outputWriterFactory = - fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataColumns.toStructType) + fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataSchema) val description = new WriteJobDescription( uuid = UUID.randomUUID().toString, @@ -262,30 +225,27 @@ object FileFormatWriter extends Logging { committer.setupTask(taskAttemptContext) - val writeTask = + val dataWriter = if (sparkPartitionId != 0 && !iterator.hasNext) { // In case of empty job, leave first partition to save meta for file format like parquet. - new EmptyDirectoryWriteTask(description) + new EmptyDirectoryDataWriter(description, taskAttemptContext, committer) } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) { - new SingleDirectoryWriteTask(description, taskAttemptContext, committer) + new SingleDirectoryDataWriter(description, taskAttemptContext, committer) } else { - new DynamicPartitionWriteTask(description, taskAttemptContext, committer) + new DynamicPartitionDataWriter(description, taskAttemptContext, committer) } try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Execute the task to write rows out and commit the task. - val summary = writeTask.execute(iterator) - writeTask.releaseResources() - WriteTaskResult(committer.commitTask(taskAttemptContext), summary) - })(catchBlock = { - // If there is an error, release resource and then abort the task - try { - writeTask.releaseResources() - } finally { - committer.abortTask(taskAttemptContext) - logError(s"Job $jobId aborted.") + while (iterator.hasNext) { + dataWriter.write(iterator.next()) } + dataWriter.commit() + })(catchBlock = { + // If there is an error, abort the task + dataWriter.abort() + logError(s"Job $jobId aborted.") }) } catch { case e: FetchFailedException => @@ -302,7 +262,7 @@ object FileFormatWriter extends Logging { private def processStats( statsTrackers: Seq[WriteJobStatsTracker], statsPerTask: Seq[Seq[WriteTaskStats]]) - : Unit = { + : Unit = { val numStatsTrackers = statsTrackers.length assert(statsPerTask.forall(_.length == numStatsTrackers), @@ -321,281 +281,4 @@ object FileFormatWriter extends Logging { case (statsTracker, stats) => statsTracker.processStats(stats) } } - - /** - * A simple trait for writing out data in a single Spark task, without any concerns about how - * to commit or abort tasks. Exceptions thrown by the implementation of this trait will - * automatically trigger task aborts. - */ - private trait ExecuteWriteTask { - - /** - * Writes data out to files, and then returns the summary of relative information which - * includes the list of partition strings written out. The list of partitions is sent back - * to the driver and used to update the catalog. Other information will be sent back to the - * driver too and used to e.g. update the metrics in UI. - */ - def execute(iterator: Iterator[InternalRow]): ExecutedWriteSummary - def releaseResources(): Unit - } - - /** ExecuteWriteTask for empty partitions */ - private class EmptyDirectoryWriteTask(description: WriteJobDescription) - extends ExecuteWriteTask { - - val statsTrackers: Seq[WriteTaskStatsTracker] = - description.statsTrackers.map(_.newTaskInstance()) - - override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { - ExecutedWriteSummary( - updatedPartitions = Set.empty, - stats = statsTrackers.map(_.getFinalStats())) - } - - override def releaseResources(): Unit = {} - } - - /** Writes data to a single directory (used for non-dynamic-partition writes). */ - private class SingleDirectoryWriteTask( - description: WriteJobDescription, - taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) extends ExecuteWriteTask { - - private[this] var currentWriter: OutputWriter = _ - - val statsTrackers: Seq[WriteTaskStatsTracker] = - description.statsTrackers.map(_.newTaskInstance()) - - private def newOutputWriter(fileCounter: Int): Unit = { - val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) - val currentPath = committer.newTaskTempFile( - taskAttemptContext, - None, - f"-c$fileCounter%03d" + ext) - - currentWriter = description.outputWriterFactory.newInstance( - path = currentPath, - dataSchema = description.dataColumns.toStructType, - context = taskAttemptContext) - - statsTrackers.map(_.newFile(currentPath)) - } - - override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { - var fileCounter = 0 - var recordsInFile: Long = 0L - newOutputWriter(fileCounter) - - while (iter.hasNext) { - if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { - fileCounter += 1 - assert(fileCounter < MAX_FILE_COUNTER, - s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") - - recordsInFile = 0 - releaseResources() - newOutputWriter(fileCounter) - } - - val internalRow = iter.next() - currentWriter.write(internalRow) - statsTrackers.foreach(_.newRow(internalRow)) - recordsInFile += 1 - } - releaseResources() - ExecutedWriteSummary( - updatedPartitions = Set.empty, - stats = statsTrackers.map(_.getFinalStats())) - } - - override def releaseResources(): Unit = { - if (currentWriter != null) { - try { - currentWriter.close() - } finally { - currentWriter = null - } - } - } - } - - /** - * Writes data to using dynamic partition writes, meaning this single function can write to - * multiple directories (partitions) or files (bucketing). - */ - private class DynamicPartitionWriteTask( - desc: WriteJobDescription, - taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) extends ExecuteWriteTask { - - /** Flag saying whether or not the data to be written out is partitioned. */ - val isPartitioned = desc.partitionColumns.nonEmpty - - /** Flag saying whether or not the data to be written out is bucketed. */ - val isBucketed = desc.bucketIdExpression.isDefined - - assert(isPartitioned || isBucketed, - s"""DynamicPartitionWriteTask should be used for writing out data that's either - |partitioned or bucketed. In this case neither is true. - |WriteJobDescription: ${desc} - """.stripMargin) - - // currentWriter is initialized whenever we see a new key (partitionValues + BucketId) - private var currentWriter: OutputWriter = _ - - /** Trackers for computing various statistics on the data as it's being written out. */ - private val statsTrackers: Seq[WriteTaskStatsTracker] = - desc.statsTrackers.map(_.newTaskInstance()) - - /** Extracts the partition values out of an input row. */ - private lazy val getPartitionValues: InternalRow => UnsafeRow = { - val proj = UnsafeProjection.create(desc.partitionColumns, desc.allColumns) - row => proj(row) - } - - /** Expression that given partition columns builds a path string like: col1=val/col2=val/... */ - private lazy val partitionPathExpression: Expression = Concat( - desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => - val partitionName = ScalaUDF( - ExternalCatalogUtils.getPartitionPathString _, - StringType, - Seq(Literal(c.name), Cast(c, StringType, Option(desc.timeZoneId)))) - if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) - }) - - /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns - * the partition string. */ - private lazy val getPartitionPath: InternalRow => String = { - val proj = UnsafeProjection.create(Seq(partitionPathExpression), desc.partitionColumns) - row => proj(row).getString(0) - } - - /** Given an input row, returns the corresponding `bucketId` */ - private lazy val getBucketId: InternalRow => Int = { - val proj = UnsafeProjection.create(desc.bucketIdExpression.toSeq, desc.allColumns) - row => proj(row).getInt(0) - } - - /** Returns the data columns to be written given an input row */ - private val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns) - - /** - * Opens a new OutputWriter given a partition key and/or a bucket id. - * If bucket id is specified, we will append it to the end of the file name, but before the - * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet - * - * @param partitionValues the partition which all tuples being written by this `OutputWriter` - * belong to - * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to - * @param fileCounter the number of files that have been written in the past for this specific - * partition. This is used to limit the max number of records written for a - * single file. The value should start from 0. - * @param updatedPartitions the set of updated partition paths, we should add the new partition - * path of this writer to it. - */ - private def newOutputWriter( - partitionValues: Option[InternalRow], - bucketId: Option[Int], - fileCounter: Int, - updatedPartitions: mutable.Set[String]): Unit = { - - val partDir = partitionValues.map(getPartitionPath(_)) - partDir.foreach(updatedPartitions.add) - - val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") - - // This must be in a form that matches our bucketing format. See BucketingUtils. - val ext = f"$bucketIdStr.c$fileCounter%03d" + - desc.outputWriterFactory.getFileExtension(taskAttemptContext) - - val customPath = partDir.flatMap { dir => - desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) - } - val currentPath = if (customPath.isDefined) { - committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) - } else { - committer.newTaskTempFile(taskAttemptContext, partDir, ext) - } - - currentWriter = desc.outputWriterFactory.newInstance( - path = currentPath, - dataSchema = desc.dataColumns.toStructType, - context = taskAttemptContext) - - statsTrackers.foreach(_.newFile(currentPath)) - } - - override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { - // If anything below fails, we should abort the task. - var recordsInFile: Long = 0L - var fileCounter = 0 - val updatedPartitions = mutable.Set[String]() - var currentPartionValues: Option[UnsafeRow] = None - var currentBucketId: Option[Int] = None - - for (row <- iter) { - val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(row)) else None - val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None - - if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) { - // See a new partition or bucket - write to a new partition dir (or a new bucket file). - if (isPartitioned && currentPartionValues != nextPartitionValues) { - currentPartionValues = Some(nextPartitionValues.get.copy()) - statsTrackers.foreach(_.newPartition(currentPartionValues.get)) - } - if (isBucketed) { - currentBucketId = nextBucketId - statsTrackers.foreach(_.newBucket(currentBucketId.get)) - } - - recordsInFile = 0 - fileCounter = 0 - - releaseResources() - newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions) - } else if (desc.maxRecordsPerFile > 0 && - recordsInFile >= desc.maxRecordsPerFile) { - // Exceeded the threshold in terms of the number of records per file. - // Create a new file by increasing the file counter. - recordsInFile = 0 - fileCounter += 1 - assert(fileCounter < MAX_FILE_COUNTER, - s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") - - releaseResources() - newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions) - } - val outputRow = getOutputRow(row) - currentWriter.write(outputRow) - statsTrackers.foreach(_.newRow(outputRow)) - recordsInFile += 1 - } - releaseResources() - - ExecutedWriteSummary( - updatedPartitions = updatedPartitions.toSet, - stats = statsTrackers.map(_.getFinalStats())) - } - - override def releaseResources(): Unit = { - if (currentWriter != null) { - try { - currentWriter.close() - } finally { - currentWriter = null - } - } - } - } } - -/** - * Wrapper class for the metrics of writing data out. - * - * @param updatedPartitions the partitions updated during writing data out. Only valid - * for dynamic partition. - * @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had. - */ -case class ExecutedWriteSummary( - updatedPartitions: Set[String], - stats: Seq[WriteTaskStats]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 16b22717b8d92..fe27b78bf3360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.FileSourceScanExec -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} +import org.apache.spark.util.collection.BitSet /** * A strategy for planning scans over collections of files that might be partitioned or bucketed @@ -50,6 +51,91 @@ import org.apache.spark.sql.execution.SparkPlan * and add it. Proceed to the next file. */ object FileSourceStrategy extends Strategy with Logging { + + // should prune buckets iff num buckets is greater than 1 and there is only one bucket column + private def shouldPruneBuckets(bucketSpec: Option[BucketSpec]): Boolean = { + bucketSpec match { + case Some(spec) => spec.bucketColumnNames.length == 1 && spec.numBuckets > 1 + case None => false + } + } + + private def getExpressionBuckets( + expr: Expression, + bucketColumnName: String, + numBuckets: Int): BitSet = { + + def getBucketNumber(attr: Attribute, v: Any): Int = { + BucketingUtils.getBucketIdFromValue(attr, numBuckets, v) + } + + def getBucketSetFromIterable(attr: Attribute, iter: Iterable[Any]): BitSet = { + val matchedBuckets = new BitSet(numBuckets) + iter + .map(v => getBucketNumber(attr, v)) + .foreach(bucketNum => matchedBuckets.set(bucketNum)) + matchedBuckets + } + + def getBucketSetFromValue(attr: Attribute, v: Any): BitSet = { + val matchedBuckets = new BitSet(numBuckets) + matchedBuckets.set(getBucketNumber(attr, v)) + matchedBuckets + } + + expr match { + case expressions.Equality(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => + getBucketSetFromValue(a, v) + case expressions.In(a: Attribute, list) + if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => + getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow))) + case expressions.InSet(a: Attribute, hset) + if hset.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => + getBucketSetFromIterable(a, hset.map(e => expressions.Literal(e).eval(EmptyRow))) + case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => + getBucketSetFromValue(a, null) + case expressions.And(left, right) => + getExpressionBuckets(left, bucketColumnName, numBuckets) & + getExpressionBuckets(right, bucketColumnName, numBuckets) + case expressions.Or(left, right) => + getExpressionBuckets(left, bucketColumnName, numBuckets) | + getExpressionBuckets(right, bucketColumnName, numBuckets) + case _ => + val matchedBuckets = new BitSet(numBuckets) + matchedBuckets.setUntil(numBuckets) + matchedBuckets + } + } + + private def genBucketSet( + normalizedFilters: Seq[Expression], + bucketSpec: BucketSpec): Option[BitSet] = { + if (normalizedFilters.isEmpty) { + return None + } + + val bucketColumnName = bucketSpec.bucketColumnNames.head + val numBuckets = bucketSpec.numBuckets + + val normalizedFiltersAndExpr = normalizedFilters + .reduce(expressions.And) + val matchedBuckets = getExpressionBuckets(normalizedFiltersAndExpr, bucketColumnName, + numBuckets) + + val numBucketsSelected = matchedBuckets.cardinality() + + logInfo { + s"Pruned ${numBuckets - numBucketsSelected} out of $numBuckets buckets." + } + + // None means all the buckets need to be scanned + if (numBucketsSelected == numBuckets) { + None + } else { + Some(matchedBuckets) + } + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) => @@ -76,9 +162,19 @@ object FileSourceStrategy extends Strategy with Logging { fsRelation.partitionSchema, fsRelation.sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters + .filterNot(SubqueryExpression.hasSubquery(_)) + .filter(_.references.subsetOf(partitionSet))) + logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") + val bucketSpec: Option[BucketSpec] = fsRelation.bucketSpec + val bucketSet = if (shouldPruneBuckets(bucketSpec)) { + genBucketSet(normalizedFilters, bucketSpec.get) + } else { + None + } + val dataColumns = l.resolve(fsRelation.dataSchema, fsRelation.sparkSession.sessionState.analyzer.resolver) @@ -108,6 +204,7 @@ object FileSourceStrategy extends Strategy with Logging { outputAttributes, outputSchema, partitionKeyFilters.toSeq, + bucketSet, dataFilters, table.map(_.identifier)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 739d1f456e3ec..9d9f8bd5bb58e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -294,9 +294,12 @@ object InMemoryFileIndex extends Logging { if (filter != null) allFiles.filter(f => filter.accept(f.getPath)) else allFiles } - allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { + val missingFiles = mutable.ArrayBuffer.empty[String] + val filteredLeafStatuses = allLeafStatuses.filterNot( + status => shouldFilterOut(status.getPath.getName)) + val resolvedLeafStatuses = filteredLeafStatuses.flatMap { case f: LocatedFileStatus => - f + Some(f) // NOTE: // @@ -311,14 +314,27 @@ object InMemoryFileIndex extends Logging { // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), // which is very slow on some file system (RawLocalFileSystem, which is launch a // subprocess and parse the stdout). - val locations = fs.getFileBlockLocations(f, 0, f.getLen) - val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, - f.getModificationTime, 0, null, null, null, null, f.getPath, locations) - if (f.isSymlink) { - lfs.setSymlink(f.getSymlink) + try { + val locations = fs.getFileBlockLocations(f, 0, f.getLen) + val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, + f.getModificationTime, 0, null, null, null, null, f.getPath, locations) + if (f.isSymlink) { + lfs.setSymlink(f.getSymlink) + } + Some(lfs) + } catch { + case _: FileNotFoundException => + missingFiles += f.getPath.toString + None } - lfs } + + if (missingFiles.nonEmpty) { + logWarning( + s"the following files were missing during file scan:\n ${missingFiles.mkString("\n ")}") + } + + resolvedLeafStatuses } /** Checks if we should filter out this path name. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index a813829d50cb1..80d7608a22891 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -38,9 +38,8 @@ case class InsertIntoDataSourceCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = Dataset.ofRows(sparkSession, query) - // Apply the schema of the existing table to the new data. - val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) - relation.insert(df, overwrite) + // Data has been casted to the target relation's schema by the PreprocessTableInsertion rule. + relation.insert(data, overwrite) // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this // data source relation. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 3b830accb83f0..16b2367bfdd5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -55,7 +55,9 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { partitionSchema, sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters + .filterNot(SubqueryExpression.hasSubquery(_)) + .filter(_.references.subsetOf(partitionSet))) if (partitionKeyFilters.nonEmpty) { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 568e953a5db66..00b1b5dedb593 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.CreatableRelationProvider -import org.apache.spark.util.Utils /** * Saves the results of `query` in to a data source. @@ -50,7 +49,7 @@ case class SaveIntoDataSourceCommand( } override def simpleString: String = { - val redacted = Utils.redact(SparkEnv.get.conf, options.toSeq).toMap + val redacted = SQLConf.get.redactOptions(options) s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 4870d75fc5f08..82322df407521 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} +import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow @@ -50,7 +51,10 @@ abstract class CSVDataSource extends Serializable { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] + requiredSchema: StructType, + // Actual schema of data in the csv file + dataSchema: StructType, + caseSensitive: Boolean): Iterator[InternalRow] /** * Infers the schema from `inputPaths` files. @@ -110,7 +114,7 @@ abstract class CSVDataSource extends Serializable { } } -object CSVDataSource { +object CSVDataSource extends Logging { def apply(options: CSVOptions): CSVDataSource = { if (options.multiLine) { MultiLineCSVDataSource @@ -118,6 +122,84 @@ object CSVDataSource { TextInputCSVDataSource } } + + /** + * Checks that column names in a CSV header and field names in the schema are the same + * by taking into account case sensitivity. + * + * @param schema - provided (or inferred) schema to which CSV must conform. + * @param columnNames - names of CSV columns that must be checked against to the schema. + * @param fileName - name of CSV file that are currently checked. It is used in error messages. + * @param enforceSchema - if it is `true`, column names are ignored otherwise the CSV column + * names are checked for conformance to the schema. In the case if + * the column name don't conform to the schema, an exception is thrown. + * @param caseSensitive - if it is set to `false`, comparison of column names and schema field + * names is not case sensitive. + */ + def checkHeaderColumnNames( + schema: StructType, + columnNames: Array[String], + fileName: String, + enforceSchema: Boolean, + caseSensitive: Boolean): Unit = { + if (columnNames != null) { + val fieldNames = schema.map(_.name).toIndexedSeq + val (headerLen, schemaSize) = (columnNames.size, fieldNames.length) + var errorMessage: Option[String] = None + + if (headerLen == schemaSize) { + var i = 0 + while (errorMessage.isEmpty && i < headerLen) { + var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i)) + if (!caseSensitive) { + nameInSchema = nameInSchema.toLowerCase + nameInHeader = nameInHeader.toLowerCase + } + if (nameInHeader != nameInSchema) { + errorMessage = Some( + s"""|CSV header does not conform to the schema. + | Header: ${columnNames.mkString(", ")} + | Schema: ${fieldNames.mkString(", ")} + |Expected: ${fieldNames(i)} but found: ${columnNames(i)} + |CSV file: $fileName""".stripMargin) + } + i += 1 + } + } else { + errorMessage = Some( + s"""|Number of column in CSV header is not equal to number of fields in the schema: + | Header length: $headerLen, schema size: $schemaSize + |CSV file: $fileName""".stripMargin) + } + + errorMessage.foreach { msg => + if (enforceSchema) { + logWarning(msg) + } else { + throw new IllegalArgumentException(msg) + } + } + } + } + + /** + * Checks that CSV header contains the same column names as fields names in the given schema + * by taking into account case sensitivity. + */ + def checkHeader( + header: String, + parser: CsvParser, + schema: StructType, + fileName: String, + enforceSchema: Boolean, + caseSensitive: Boolean): Unit = { + checkHeaderColumnNames( + schema, + parser.parseLine(header), + fileName, + enforceSchema, + caseSensitive) + } } object TextInputCSVDataSource extends CSVDataSource { @@ -127,7 +209,9 @@ object TextInputCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + requiredSchema: StructType, + dataSchema: StructType, + caseSensitive: Boolean): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) @@ -136,8 +220,24 @@ object TextInputCSVDataSource extends CSVDataSource { } } - val shouldDropHeader = parser.options.headerFlag && file.start == 0 - UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema) + val hasHeader = parser.options.headerFlag && file.start == 0 + if (hasHeader) { + // Checking that column names in the header are matched to field names of the schema. + // The header will be removed from lines. + // Note: if there are only comments in the first block, the header would probably + // be not extracted. + CSVUtils.extractHeader(lines, parser.options).foreach { header => + CSVDataSource.checkHeader( + header, + parser.tokenizer, + dataSchema, + file.filePath, + parser.options.enforceSchema, + caseSensitive) + } + } + + UnivocityParser.parseIterator(lines, parser, requiredSchema) } override def infer( @@ -161,7 +261,8 @@ object TextInputCSVDataSource extends CSVDataSource { val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.rdd.mapPartitions { iter => + val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions) + val tokenRDD = sampled.rdd.mapPartitions { iter => val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) @@ -184,7 +285,8 @@ object TextInputCSVDataSource extends CSVDataSource { DataSource.apply( sparkSession, paths = paths, - className = classOf[TextFileFormat].getName + className = classOf[TextFileFormat].getName, + options = options.parameters ).resolveRelation(checkFilesExist = false)) .select("value").as[String](Encoders.STRING) } else { @@ -204,12 +306,24 @@ object MultiLineCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + requiredSchema: StructType, + dataSchema: StructType, + caseSensitive: Boolean): Iterator[InternalRow] = { + def checkHeader(header: Array[String]): Unit = { + CSVDataSource.checkHeaderColumnNames( + dataSchema, + header, + file.filePath, + parser.options.enforceSchema, + caseSensitive) + } + UnivocityParser.parseStream( CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))), parser.options.headerFlag, parser, - schema) + requiredSchema, + checkHeader) } override def infer( @@ -235,7 +349,8 @@ object MultiLineCSVDataSource extends CSVDataSource { parsedOptions.headerFlag, new CsvParser(parsedOptions.asParserSettings)) } - CSVInferSchema.infer(tokenRDD, header, parsedOptions) + val sampled = CSVUtils.sample(tokenRDD, parsedOptions) + CSVInferSchema.infer(sampled, header, parsedOptions) case None => // If the first row could not be read, just return the empty schema. StructType(Nil) @@ -248,7 +363,8 @@ object MultiLineCSVDataSource extends CSVDataSource { options: CSVOptions): RDD[PortableDataStream] = { val paths = inputPaths.map(_.getPath) val name = paths.mkString(",") - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + val job = Job.getInstance(sparkSession.sessionState.newHadoopConfWithOptions( + options.parameters)) FileInputFormat.setInputPaths(job, paths: _*) val conf = job.getConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index e20977a4ec79f..aeb40e5a4131d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -41,8 +41,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], path: Path): Boolean = { - val parsedOptions = - new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val parsedOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) val csvDataSource = CSVDataSource(parsedOptions) csvDataSource.isSplitable && super.isSplitable(sparkSession, options, path) } @@ -51,8 +53,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val parsedOptions = - new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val parsedOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions) } @@ -62,9 +66,11 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - CSVUtils.verifySchema(dataSchema) val conf = job.getConfiguration - val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val csvOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) csvOptions.compressionCodec.foreach { codec => CompressionCodecs.setCodecConfiguration(conf, codec) } @@ -91,12 +97,12 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - CSVUtils.verifySchema(dataSchema) val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) val parsedOptions = new CSVOptions( options, + sparkSession.sessionState.conf.csvColumnPruning, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -122,6 +128,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { "df.filter($\"_corrupt_record\".isNotNull).count()." ) } + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -129,7 +136,13 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), parsedOptions) - CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema) + CSVDataSource(parsedOptions).readFile( + conf, + file, + parser, + requiredSchema, + dataSchema, + caseSensitive) } } @@ -138,6 +151,15 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { override def hashCode(): Int = getClass.hashCode() override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _ => false + } + } private[csv] class CsvOutputWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index c16790630ce17..fab8d62da0c1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -27,17 +27,20 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ class CSVOptions( - @transient private val parameters: CaseInsensitiveMap[String], + @transient val parameters: CaseInsensitiveMap[String], + val columnPruning: Boolean, defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { def this( parameters: Map[String, String], + columnPruning: Boolean, defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String = "") = { this( CaseInsensitiveMap(parameters), + columnPruning, defaultTimeZoneId, defaultColumnNameOfCorruptRecord) } @@ -150,6 +153,15 @@ class CSVOptions( val isCommentSet = this.comment != '\u0000' + val samplingRatio = + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + + /** + * Forcibly apply the specified or inferred schema to datasource files. + * If the option is enabled, headers of CSV files will be ignored. + */ + val enforceSchema = getBool("enforceSchema", default = true) + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat @@ -161,7 +173,7 @@ class CSVOptions( writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) - writerSettings.setEmptyValue(nullValue) + writerSettings.setEmptyValue("\"\"") writerSettings.setSkipEmptyLines(true) writerSettings.setQuoteAllFields(quoteAll) writerSettings.setQuoteEscapingEnabled(escapeQuotes) @@ -182,6 +194,7 @@ class CSVOptions( settings.setInputBufferSize(inputBufferSize) settings.setMaxColumns(maxColumns) settings.setNullValue(nullValue) + settings.setEmptyValue("") settings.setMaxCharsPerColumn(maxCharsPerColumn) settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) settings diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 72b053d2092ca..7ce65fa89b02d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.csv +import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -67,12 +68,8 @@ object CSVUtils { } } - /** - * Drop header line so that only data can remain. - * This is similar with `filterHeaderLine` above and currently being used in CSV reading path. - */ - def dropHeaderLine(iter: Iterator[String], options: CSVOptions): Iterator[String] = { - val nonEmptyLines = if (options.isCommentSet) { + def skipComments(iter: Iterator[String], options: CSVOptions): Iterator[String] = { + if (options.isCommentSet) { val commentPrefix = options.comment.toString iter.dropWhile { line => line.trim.isEmpty || line.trim.startsWith(commentPrefix) @@ -80,11 +77,19 @@ object CSVUtils { } else { iter.dropWhile(_.trim.isEmpty) } - - if (nonEmptyLines.hasNext) nonEmptyLines.drop(1) - iter } + /** + * Extracts header and moves iterator forward so that only data remains in it + */ + def extractHeader(iter: Iterator[String], options: CSVOptions): Option[String] = { + val nonEmptyLines = skipComments(iter, options) + if (nonEmptyLines.hasNext) { + Some(nonEmptyLines.next()) + } else { + None + } + } /** * Helper method that converts string representation of a character to actual character. * It handles some Java escaped strings and throws exception if given string is longer than one @@ -113,22 +118,28 @@ object CSVUtils { } /** - * Verify if the schema is supported in CSV datasource. + * Sample CSV dataset as configured by `samplingRatio`. */ - def verifySchema(schema: StructType): Unit = { - def verifyType(dataType: DataType): Unit = dataType match { - case ByteType | ShortType | IntegerType | LongType | FloatType | - DoubleType | BooleanType | _: DecimalType | TimestampType | - DateType | StringType => - - case udt: UserDefinedType[_] => verifyType(udt.sqlType) - - case _ => - throw new UnsupportedOperationException( - s"CSV data source does not support ${dataType.simpleString} data type.") + def sample(csv: Dataset[String], options: CSVOptions): Dataset[String] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + csv + } else { + csv.sample(withReplacement = false, options.samplingRatio, 1) } - - schema.foreach(field => verifyType(field.dataType)) } + /** + * Sample CSV RDD as configured by `samplingRatio`. + */ + def sample(csv: RDD[Array[String]], options: CSVOptions): RDD[Array[String]] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + csv + } else { + csv.sample(withReplacement = false, options.samplingRatio, 1) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 3d6cc30f2ba83..aa545e1a0c00a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources.csv import java.io.InputStream import java.math.BigDecimal -import java.text.NumberFormat -import java.util.Locale import scala.util.Try import scala.util.control.NonFatal @@ -36,10 +34,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParser( - schema: StructType, + dataSchema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { - require(requiredSchema.toSet.subsetOf(schema.toSet), + require(requiredSchema.toSet.subsetOf(dataSchema.toSet), "requiredSchema should be the subset of schema.") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) @@ -47,9 +45,17 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val tokenizer = new CsvParser(options.asParserSettings) + val tokenizer = { + val parserSetting = options.asParserSettings + if (options.columnPruning && requiredSchema.length < dataSchema.length) { + val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) + parserSetting.selectIndexes(tokenIndexArr: _*) + } + new CsvParser(parserSetting) + } + private val schema = if (options.columnPruning) requiredSchema else dataSchema - private val row = new GenericInternalRow(requiredSchema.length) + private val row = new GenericInternalRow(schema.length) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -75,11 +81,8 @@ class UnivocityParser( // Each input token is placed in each output row's position by mapping these. In this case, // // output row - ["A", 2] - private val valueConverters: Array[ValueConverter] = + private val valueConverters: Array[ValueConverter] = { schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - - private val tokenIndexArr: Array[Int] = { - requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -180,11 +183,19 @@ class UnivocityParser( } } + private val doParse = if (schema.nonEmpty) { + (input: String) => convert(tokenizer.parseLine(input)) + } else { + // If `columnPruning` enabled and partition attributes scanned only, + // `schema` gets empty. + (_: String) => InternalRow.empty + } + /** * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). */ - def parse(input: String): InternalRow = convert(tokenizer.parseLine(input)) + def parse(input: String): InternalRow = doParse(input) private def convert(tokens: Array[String]): InternalRow = { if (tokens.length != schema.length) { @@ -212,9 +223,8 @@ class UnivocityParser( } else { try { var i = 0 - while (i < requiredSchema.length) { - val from = tokenIndexArr(i) - row(i) = valueConverters(from).apply(tokens(from)) + while (i < schema.length) { + row(i) = valueConverters(i).apply(tokens(i)) i += 1 } row @@ -248,14 +258,15 @@ private[csv] object UnivocityParser { inputStream: InputStream, shouldDropHeader: Boolean, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + schema: StructType, + checkHeader: Array[String] => Unit): Iterator[InternalRow] = { val tokenizer = parser.tokenizer val safeParser = new FailureSafeParser[Array[String]]( input => Seq(parser.convert(input)), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) - convertStream(inputStream, shouldDropHeader, tokenizer) { tokens => + convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens => safeParser.parse(tokens) }.flatten } @@ -263,11 +274,14 @@ private[csv] object UnivocityParser { private def convertStream[T]( inputStream: InputStream, shouldDropHeader: Boolean, - tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] { + tokenizer: CsvParser, + checkHeader: Array[String] => Unit = _ => ())( + convert: Array[String] => T) = new Iterator[T] { tokenizer.beginParsing(inputStream) private var nextRecord = { if (shouldDropHeader) { - tokenizer.parseNext() + val firstRecord = tokenizer.parseNext() + checkHeader(firstRecord) } tokenizer.parseNext() } @@ -289,21 +303,11 @@ private[csv] object UnivocityParser { */ def parseIterator( lines: Iterator[String], - shouldDropHeader: Boolean, parser: UnivocityParser, schema: StructType): Iterator[InternalRow] = { val options = parser.options - val linesWithoutHeader = if (shouldDropHeader) { - // Note that if there are only comments in the first block, the header would probably - // be not dropped. - CSVUtils.dropHeaderLine(lines, options) - } else { - lines - } - - val filteredLines: Iterator[String] = - CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options) + val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options) val safeParser = new FailureSafeParser[String]( input => Seq(parser.parse(input)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index b4e5d169066d9..eea966d30948b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.StructType * Options for the JDBC data source. */ class JDBCOptions( - @transient private val parameters: CaseInsensitiveMap[String]) + @transient val parameters: CaseInsensitiveMap[String]) extends Serializable { import JDBCOptions._ @@ -65,11 +65,31 @@ class JDBCOptions( // Required parameters // ------------------------------------------------------------ require(parameters.isDefinedAt(JDBC_URL), s"Option '$JDBC_URL' is required.") - require(parameters.isDefinedAt(JDBC_TABLE_NAME), s"Option '$JDBC_TABLE_NAME' is required.") // a JDBC URL val url = parameters(JDBC_URL) - // name of table - val table = parameters(JDBC_TABLE_NAME) + // table name or a table subquery. + val tableOrQuery = (parameters.get(JDBC_TABLE_NAME), parameters.get(JDBC_QUERY_STRING)) match { + case (Some(name), Some(subquery)) => + throw new IllegalArgumentException( + s"Both '$JDBC_TABLE_NAME' and '$JDBC_QUERY_STRING' can not be specified at the same time." + ) + case (None, None) => + throw new IllegalArgumentException( + s"Option '$JDBC_TABLE_NAME' or '$JDBC_QUERY_STRING' is required." + ) + case (Some(name), None) => + if (name.isEmpty) { + throw new IllegalArgumentException(s"Option '$JDBC_TABLE_NAME' can not be empty.") + } else { + name.trim + } + case (None, Some(subquery)) => + if (subquery.isEmpty) { + throw new IllegalArgumentException(s"Option `$JDBC_QUERY_STRING` can not be empty.") + } else { + s"(${subquery}) __SPARK_GEN_JDBC_SUBQUERY_NAME_${curId.getAndIncrement()}" + } + } // ------------------------------------------------------------ // Optional parameters @@ -89,6 +109,10 @@ class JDBCOptions( // the number of partitions val numPartitions = parameters.get(JDBC_NUM_PARTITIONS).map(_.toInt) + // the number of seconds the driver will wait for a Statement object to execute to the given + // number of seconds. Zero means there is no limit. + val queryTimeout = parameters.getOrElse(JDBC_QUERY_TIMEOUT, "0").toInt + // ------------------------------------------------------------ // Optional parameters only for reading // ------------------------------------------------------------ @@ -105,6 +129,20 @@ class JDBCOptions( s"When reading JDBC data sources, users need to specify all or none for the following " + s"options: '$JDBC_PARTITION_COLUMN', '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', " + s"and '$JDBC_NUM_PARTITIONS'") + + require(!(parameters.get(JDBC_QUERY_STRING).isDefined && partitionColumn.isDefined), + s""" + |Options '$JDBC_QUERY_STRING' and '$JDBC_PARTITION_COLUMN' can not be specified together. + |Please define the query using `$JDBC_TABLE_NAME` option instead and make sure to qualify + |the partition columns using the supplied subquery alias to resolve any ambiguity. + |Example : + |spark.read.format("jdbc") + | .option("dbtable", "(select c1, c2 from t1) as subq") + | .option("partitionColumn", "subq.c1" + | .load() + """.stripMargin + ) + val fetchSize = { val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt require(size >= 0, @@ -145,7 +183,30 @@ class JDBCOptions( val sessionInitStatement = parameters.get(JDBC_SESSION_INIT_STATEMENT) } +class JdbcOptionsInWrite( + @transient override val parameters: CaseInsensitiveMap[String]) + extends JDBCOptions(parameters) { + + import JDBCOptions._ + + def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) + + def this(url: String, table: String, parameters: Map[String, String]) = { + this(CaseInsensitiveMap(parameters ++ Map( + JDBCOptions.JDBC_URL -> url, + JDBCOptions.JDBC_TABLE_NAME -> table))) + } + + require( + parameters.get(JDBC_TABLE_NAME).isDefined, + s"Option '$JDBC_TABLE_NAME' is required. " + + s"Option '$JDBC_QUERY_STRING' is not applicable while writing.") + + val table = parameters(JDBC_TABLE_NAME) +} + object JDBCOptions { + private val curId = new java.util.concurrent.atomic.AtomicLong(0L) private val jdbcOptionNames = collection.mutable.Set[String]() private def newOption(name: String): String = { @@ -155,11 +216,13 @@ object JDBCOptions { val JDBC_URL = newOption("url") val JDBC_TABLE_NAME = newOption("dbtable") + val JDBC_QUERY_STRING = newOption("query") val JDBC_DRIVER_CLASS = newOption("driver") val JDBC_PARTITION_COLUMN = newOption("partitionColumn") val JDBC_LOWER_BOUND = newOption("lowerBound") val JDBC_UPPER_BOUND = newOption("upperBound") val JDBC_NUM_PARTITIONS = newOption("numPartitions") + val JDBC_QUERY_TIMEOUT = newOption("queryTimeout") val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 05326210f3242..1b3b17c75e756 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -51,12 +51,13 @@ object JDBCRDD extends Logging { */ def resolveTable(options: JDBCOptions): StructType = { val url = options.url - val table = options.table + val table = options.tableOrQuery val dialect = JdbcDialects.get(url) val conn: Connection = JdbcUtils.createConnectionFactory(options)() try { val statement = conn.prepareStatement(dialect.getSchemaQuery(table)) try { + statement.setQueryTimeout(options.queryTimeout) val rs = statement.executeQuery() try { JdbcUtils.getSchema(rs, dialect, alwaysNullable = true) @@ -281,6 +282,7 @@ private[jdbc] class JDBCRDD( val statement = conn.prepareStatement(sql) logInfo(s"Executing sessionInitStatement: $sql") try { + statement.setQueryTimeout(options.queryTimeout) statement.execute() } finally { statement.close() @@ -294,10 +296,11 @@ private[jdbc] class JDBCRDD( val myWhereClause = getWhereClause(part) - val sqlText = s"SELECT $columnList FROM ${options.table} $myWhereClause" + val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) + stmt.setQueryTimeout(options.queryTimeout) rs = stmt.executeQuery() val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index b23e5a7722004..97e2d255cb7be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -22,10 +22,12 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * Instructions on how to partition the table among workers. @@ -48,10 +50,17 @@ private[sql] object JDBCRelation extends Logging { * Null value predicate is added to the first partition where clause to include * the rows with null value for the partitions column. * + * @param schema resolved schema of a JDBC table * @param partitioning partition information to generate the where clause for each partition + * @param resolver function used to determine if two identifiers are equal + * @param jdbcOptions JDBC options that contains url * @return an array of partitions with where clause for each partition */ - def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { + def columnPartition( + schema: StructType, + partitioning: JDBCPartitioningInfo, + resolver: Resolver, + jdbcOptions: JDBCOptions): Array[Partition] = { if (partitioning == null || partitioning.numPartitions <= 1 || partitioning.lowerBound == partitioning.upperBound) { return Array[Partition](JDBCPartition(null, 0)) @@ -78,7 +87,10 @@ private[sql] object JDBCRelation extends Logging { // Overflow and silliness can happen if you subtract then divide. // Here we get a little roundoff, but that's (hopefully) OK. val stride: Long = upperBound / numPartitions - lowerBound / numPartitions - val column = partitioning.column + + val column = verifyAndGetNormalizedColumnName( + schema, partitioning.column, resolver, jdbcOptions) + var i: Int = 0 var currentValue: Long = lowerBound val ans = new ArrayBuffer[Partition]() @@ -99,10 +111,57 @@ private[sql] object JDBCRelation extends Logging { } ans.toArray } + + // Verify column name based on the JDBC resolved schema + private def verifyAndGetNormalizedColumnName( + schema: StructType, + columnName: String, + resolver: Resolver, + jdbcOptions: JDBCOptions): String = { + val dialect = JdbcDialects.get(jdbcOptions.url) + schema.map(_.name).find { fieldName => + resolver(fieldName, columnName) || + resolver(dialect.quoteIdentifier(fieldName), columnName) + }.map(dialect.quoteIdentifier).getOrElse { + throw new AnalysisException(s"User-defined partition column $columnName not " + + s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}") + } + } + + /** + * Takes a (schema, table) specification and returns the table's Catalyst schema. + * If `customSchema` defined in the JDBC options, replaces the schema's dataType with the + * custom schema's type. + * + * @param resolver function used to determine if two identifiers are equal + * @param jdbcOptions JDBC options that contains url, table and other information. + * @return resolved Catalyst schema of a JDBC table + */ + def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = { + val tableSchema = JDBCRDD.resolveTable(jdbcOptions) + jdbcOptions.customSchema match { + case Some(customSchema) => JdbcUtils.getCustomSchema( + tableSchema, customSchema, resolver) + case None => tableSchema + } + } + + /** + * Resolves a Catalyst schema of a JDBC table and returns [[JDBCRelation]] with the schema. + */ + def apply( + parts: Array[Partition], + jdbcOptions: JDBCOptions)( + sparkSession: SparkSession): JDBCRelation = { + val schema = JDBCRelation.getSchema(sparkSession.sessionState.conf.resolver, jdbcOptions) + JDBCRelation(schema, parts, jdbcOptions)(sparkSession) + } } private[sql] case class JDBCRelation( - parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession) + override val schema: StructType, + parts: Array[Partition], + jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession) extends BaseRelation with PrunedFilteredScan with InsertableRelation { @@ -111,15 +170,6 @@ private[sql] case class JDBCRelation( override val needConversion: Boolean = false - override val schema: StructType = { - val tableSchema = JDBCRDD.resolveTable(jdbcOptions) - jdbcOptions.customSchema match { - case Some(customSchema) => JdbcUtils.getCustomSchema( - tableSchema, customSchema, sparkSession.sessionState.conf.resolver) - case None => tableSchema - } - } - // Check if JDBCRDD.compileFilter can accept input filters override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty) @@ -139,12 +189,12 @@ private[sql] case class JDBCRelation( override def insert(data: DataFrame, overwrite: Boolean): Unit = { data.write .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) - .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) + .jdbc(jdbcOptions.url, jdbcOptions.tableOrQuery, jdbcOptions.asProperties) } override def toString: String = { val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]" else "" // credentials should not be included in the plan output, table information is sufficient. - s"JDBCRelation(${jdbcOptions.table})" + partitioningInfo + s"JDBCRelation(${jdbcOptions.tableOrQuery})" + partitioningInfo } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index cc506e51bd0c6..782d626c1573c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -48,8 +48,10 @@ class JdbcRelationProvider extends CreatableRelationProvider JDBCPartitioningInfo( partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get) } - val parts = JDBCRelation.columnPartition(partitionInfo) - JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession) + val resolver = sqlContext.conf.resolver + val schema = JDBCRelation.getSchema(resolver, jdbcOptions) + val parts = JDBCRelation.columnPartition(schema, partitionInfo, resolver, jdbcOptions) + JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession) } override def createRelation( @@ -57,7 +59,7 @@ class JdbcRelationProvider extends CreatableRelationProvider mode: SaveMode, parameters: Map[String, String], df: DataFrame): BaseRelation = { - val options = new JDBCOptions(parameters) + val options = new JdbcOptionsInWrite(parameters) val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis val conn = JdbcUtils.createConnectionFactory(options)() @@ -73,7 +75,7 @@ class JdbcRelationProvider extends CreatableRelationProvider saveTable(df, tableSchema, isCaseSensitive, options) } else { // Otherwise, do not truncate the table, instead drop and recreate it - dropTable(conn, options.table) + dropTable(conn, options.table, options) createTable(conn, df, options) saveTable(df, Some(df.schema), isCaseSensitive, options) } @@ -84,7 +86,8 @@ class JdbcRelationProvider extends CreatableRelationProvider case SaveMode.ErrorIfExists => throw new AnalysisException( - s"Table or view '${options.table}' already exists. SaveMode: ErrorIfExists.") + s"Table or view '${options.table}' already exists. " + + s"SaveMode: ErrorIfExists.") case SaveMode.Ignore => // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index e6dc2fda4eb1b..b81737eda475b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -67,7 +67,7 @@ object JdbcUtils extends Logging { /** * Returns true if the table already exists in the JDBC database. */ - def tableExists(conn: Connection, options: JDBCOptions): Boolean = { + def tableExists(conn: Connection, options: JdbcOptionsInWrite): Boolean = { val dialect = JdbcDialects.get(options.url) // Somewhat hacky, but there isn't a good way to identify whether a table exists for all @@ -76,6 +76,7 @@ object JdbcUtils extends Logging { Try { val statement = conn.prepareStatement(dialect.getTableExistsQuery(options.table)) try { + statement.setQueryTimeout(options.queryTimeout) statement.executeQuery() } finally { statement.close() @@ -86,9 +87,10 @@ object JdbcUtils extends Logging { /** * Drops a table from the JDBC database. */ - def dropTable(conn: Connection, table: String): Unit = { + def dropTable(conn: Connection, table: String, options: JDBCOptions): Unit = { val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(s"DROP TABLE $table") } finally { statement.close() @@ -98,10 +100,11 @@ object JdbcUtils extends Logging { /** * Truncates a table from the JDBC database without side effects. */ - def truncateTable(conn: Connection, options: JDBCOptions): Unit = { + def truncateTable(conn: Connection, options: JdbcOptionsInWrite): Unit = { val dialect = JdbcDialects.get(options.url) val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(dialect.getTruncateQuery(options.table)) } finally { statement.close() @@ -252,8 +255,9 @@ object JdbcUtils extends Logging { val dialect = JdbcDialects.get(options.url) try { - val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table)) + val statement = conn.prepareStatement(dialect.getSchemaQuery(options.tableOrQuery)) try { + statement.setQueryTimeout(options.queryTimeout) Some(getSchema(statement.executeQuery(), dialect)) } catch { case _: SQLException => None @@ -596,7 +600,8 @@ object JdbcUtils extends Logging { insertStmt: String, batchSize: Int, dialect: JdbcDialect, - isolationLevel: Int): Iterator[Byte] = { + isolationLevel: Int, + options: JDBCOptions): Iterator[Byte] = { val conn = getConnection() var committed = false @@ -637,6 +642,9 @@ object JdbcUtils extends Logging { try { var rowCount = 0 + + stmt.setQueryTimeout(options.queryTimeout) + while (iterator.hasNext) { val row = iterator.next() var i = 0 @@ -801,7 +809,7 @@ object JdbcUtils extends Logging { df: DataFrame, tableSchema: Option[StructType], isCaseSensitive: Boolean, - options: JDBCOptions): Unit = { + options: JdbcOptionsInWrite): Unit = { val url = options.url val table = options.table val dialect = JdbcDialects.get(url) @@ -819,7 +827,8 @@ object JdbcUtils extends Logging { case _ => df } repartitionedDF.rdd.foreachPartition(iterator => savePartition( - getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel) + getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, + options) ) } @@ -829,7 +838,7 @@ object JdbcUtils extends Logging { def createTable( conn: Connection, df: DataFrame, - options: JDBCOptions): Unit = { + options: JdbcOptionsInWrite): Unit = { val strSchema = schemaString( df, options.url, options.createTableColumnTypes) val table = options.table @@ -841,6 +850,7 @@ object JdbcUtils extends Logging { val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions" val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(sql) } finally { statement.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 5769c09c9a1d9..2fee2128ba1f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -31,11 +31,12 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} -import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions} +import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -92,32 +93,33 @@ object TextInputJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: Dataset[String] = createBaseDataset( - sparkSession, inputPaths, parsedOptions.lineSeparator) + val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) + inferFromDataset(json, parsedOptions) } def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = { val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions) - val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0)) - JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String) + val rdd: RDD[InternalRow] = sampled.queryExecution.toRdd + val rowParser = parsedOptions.encoding.map { enc => + CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) + }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) + + SQLExecution.withSQLConfPropagated(json.sparkSession) { + JsonInferSchema.infer(rdd, parsedOptions, rowParser) + } } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], - lineSeparator: Option[String]): Dataset[String] = { - val textOptions = lineSeparator.map { lineSep => - Map(TextOptions.LINE_SEPARATOR -> lineSep) - }.getOrElse(Map.empty[String, String]) - - val paths = inputPaths.map(_.getPath.toString) + parsedOptions: JSONOptions): Dataset[String] = { sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, - paths = paths, + paths = inputPaths.map(_.getPath.toString), className = classOf[TextFileFormat].getName, - options = textOptions + options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) .select("value").as(Encoders.STRING) } @@ -129,8 +131,12 @@ object TextInputJsonDataSource extends JsonDataSource { schema: StructType): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + val textParser = parser.options.encoding + .map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text)) + .getOrElse(CreateJacksonParser.text(_: JsonFactory, _: Text)) + val safeParser = new FailureSafeParser[Text]( - input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), + input => parser.parse(input, textParser, textToUTF8String), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) @@ -151,16 +157,24 @@ object MultiLineJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths) + val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions) val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions) - JsonInferSchema.infer(sampled, parsedOptions, createParser) + val parser = parsedOptions.encoding + .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) + .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) + + SQLExecution.withSQLConfPropagated(sparkSession) { + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + } } private def createBaseRdd( sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = { + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): RDD[PortableDataStream] = { val paths = inputPaths.map(_.getPath) - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + val job = Job.getInstance(sparkSession.sessionState.newHadoopConfWithOptions( + parsedOptions.parameters)) val conf = job.getConfiguration val name = paths.mkString(",") FileInputFormat.setInputPaths(job, paths: _*) @@ -175,11 +189,18 @@ object MultiLineJsonDataSource extends JsonDataSource { .values } - private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { - val path = new Path(record.getPath()) - CreateJacksonParser.inputStream( - jsonFactory, - CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, path)) + private def dataToInputStream(dataStream: PortableDataStream): InputStream = { + val path = new Path(dataStream.getPath()) + CodecStreams.createInputStreamWithCloseResource(dataStream.getConfiguration, path) + } + + private def createParser(jsonFactory: JsonFactory, stream: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream(jsonFactory, dataToInputStream(stream)) + } + + private def createParser(enc: String, jsonFactory: JsonFactory, + stream: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream(enc, jsonFactory, dataToInputStream(stream)) } override def readFile( @@ -194,9 +215,12 @@ object MultiLineJsonDataSource extends JsonDataSource { UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) } } + val streamParser = parser.options.encoding + .map(enc => CreateJacksonParser.inputStream(enc, _: JsonFactory, _: InputStream)) + .getOrElse(CreateJacksonParser.inputStream(_: JsonFactory, _: InputStream)) val safeParser = new FailureSafeParser[InputStream]( - input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString), + input => parser.parse[InputStream](input, streamParser, partitionedFileString), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 0862c746fffad..a9241afba537b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.json +import java.nio.charset.{Charset, StandardCharsets} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -24,11 +26,11 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions, JSONOptionsInRead} import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { @@ -38,7 +40,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], path: Path): Boolean = { - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -50,7 +52,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -97,7 +99,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -142,6 +144,23 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { override def hashCode(): Int = getClass.hashCode() override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat] + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => true + + case _ => false + } } private[json] class JsonOutputWriter( @@ -151,7 +170,18 @@ private[json] class JsonOutputWriter( context: TaskAttemptContext) extends OutputWriter with Logging { - private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) + private val encoding = options.encoding match { + case Some(charsetName) => Charset.forName(charsetName) + case None => StandardCharsets.UTF_8 + } + + if (JSONOptionsInRead.blacklist.contains(encoding)) { + logWarning(s"The JSON file ($path) was written in the encoding ${encoding.displayName()}" + + " which can be read back by Spark only if multiLine is enabled.") + } + + private val writer = CodecStreams.createOutputStreamWriter( + context, new Path(path), encoding) // create the Generator without separator inserted between 2 records private[this] val gen = new JacksonGenerator(dataSchema, writer, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 1de2ca2914c44..3a8c0add8c2f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -224,4 +224,21 @@ class OrcFileFormat } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => isReadPath + + case _ => false + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index d8f47eec952de..3ec33b2f4b540 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -34,6 +34,7 @@ import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.ParquetOutputFormat.JobSummaryLevel import org.apache.parquet.hadoop.codec.CodecConfig import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType @@ -77,7 +78,6 @@ class ParquetFileFormat job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf) val conf = ContextUtil.getConfiguration(job) @@ -125,16 +125,17 @@ class ParquetFileFormat conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) // SPARK-15719: Disables writing Parquet summary files by default. - if (conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { - conf.setBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) + if (conf.get(ParquetOutputFormat.JOB_SUMMARY_LEVEL) == null + && conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { + conf.setEnum(ParquetOutputFormat.JOB_SUMMARY_LEVEL, JobSummaryLevel.NONE) } - if (conf.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) + if (ParquetOutputFormat.getJobSummaryLevel(conf) == JobSummaryLevel.NONE && !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass)) { // output summary is requested, but the class is not a Parquet Committer logWarning(s"Committer $committerClass is not a ParquetOutputCommitter and cannot" + s" create job summaries. " + - s"Set Parquet option ${ParquetOutputFormat.ENABLE_JOB_SUMMARY} to false.") + s"Set Parquet option ${ParquetOutputFormat.JOB_SUMMARY_LEVEL} to NONE.") } new OutputWriterFactory { @@ -333,37 +334,27 @@ class ParquetFileFormat val enableVectorizedReader: Boolean = sqlConf.parquetVectorizedReaderEnabled && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) - val enableRecordFilter: Boolean = - sparkSession.sessionState.conf.parquetRecordFilterEnabled - val timestampConversion: Boolean = - sparkSession.sessionState.conf.isParquetINT96TimestampConversion + val enableRecordFilter: Boolean = sqlConf.parquetRecordFilterEnabled + val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion val capacity = sqlConf.parquetVectorizedReaderBatchSize - val enableParquetFilterPushDown: Boolean = - sparkSession.sessionState.conf.parquetFilterPushDown + val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) + val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) - // Try to push down filters when filter push-down is enabled. - val pushed = if (enableParquetFilterPushDown) { - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(ParquetFilters.createFilter(requiredSchema, _)) - .reduceOption(FilterApi.and) - } else { - None - } - val fileSplit = new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) + val filePath = fileSplit.getPath val split = new org.apache.parquet.hadoop.ParquetInputSplit( - fileSplit.getPath, + filePath, fileSplit.getStart, fileSplit.getStart + fileSplit.getLength, fileSplit.getLength, @@ -371,12 +362,29 @@ class ParquetFileFormat null) val sharedConf = broadcastedHadoopConf.value.value + + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + val parquetSchema = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) + .getFileMetaData.getSchema + val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, + pushDownStringStartWith, pushDownInFilterThreshold) + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(parquetFilters.createFilter(parquetSchema, _)) + .reduceOption(FilterApi.and) + } else { + None + } + // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps' // *only* if the file was created by something other than "parquet-mr", so check the actual // writer here for this file. We have to do this per-file, as each file in the table may // have different writers. def isCreatedByParquetMr(): Boolean = { - val footer = ParquetFileReader.readFooter(sharedConf, fileSplit.getPath, SKIP_ROW_GROUPS) + val footer = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) footer.getFileMetaData().getCreatedBy().startsWith("parquet-mr") } val convertTz = @@ -442,6 +450,21 @@ class ParquetFileFormat } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _ => false + } } object ParquetFileFormat extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index ccc8306866d68..0c146f2f6f915 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -17,188 +17,269 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.sql.Date +import java.lang.{Long => JLong} +import java.sql.{Date, Timestamp} + +import scala.collection.JavaConverters.asScalaBufferConverter import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator} +import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources -import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Some utility function to convert Spark data source filters to Parquet filters. */ -private[parquet] object ParquetFilters { +private[parquet] class ParquetFilters( + pushDownDate: Boolean, + pushDownTimestamp: Boolean, + pushDownStartWith: Boolean, + pushDownInFilterThreshold: Int) { + + private case class ParquetSchemaType( + originalType: OriginalType, + primitiveTypeName: PrimitiveTypeName, + decimalMetadata: DecimalMetadata) + + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, null) + private val ParquetByteType = ParquetSchemaType(INT_8, INT32, null) + private val ParquetShortType = ParquetSchemaType(INT_16, INT32, null) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, null) + private val ParquetLongType = ParquetSchemaType(null, INT64, null) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, null) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, null) + private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, null) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, null) + private val ParquetDateType = ParquetSchemaType(DATE, INT32, null) + private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, null) + private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, null) private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) } - private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => + private val makeEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetBooleanType => (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => - (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) // Binary.fromString and Binary.fromByteArray don't accept null values - case StringType => + case ParquetStringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case BinaryType => + case ParquetBinaryType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) + .asInstanceOf[JLong]).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) } - private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => + private val makeNotEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetBooleanType => (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => - (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) + case ParquetLongType => (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => + case ParquetStringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case BinaryType => + case ParquetBinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.notEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) + .asInstanceOf[JLong]).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) } - private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer]) - case LongType => + private val makeLt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => + case ParquetStringType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.lt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => (n: String, v: Any) => FilterApi.lt( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.lt( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) } - private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => + private val makeLtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => + case ParquetStringType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => (n: String, v: Any) => FilterApi.ltEq( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.ltEq( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) } - private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => + private val makeGt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => + case ParquetStringType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => (n: String, v: Any) => FilterApi.gt( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.gt( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) } - private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => + private val makeGtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) + case ParquetLongType => (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => + case ParquetStringType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.gtEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => (n: String, v: Any) => FilterApi.gtEq( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.gtEq( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) } /** * Returns a map from name of the column to the data type, if predicate push down applies. */ - private def getFieldMap(dataType: DataType): Map[String, DataType] = dataType match { - case StructType(fields) => + private def getFieldMap(dataType: MessageType): Map[String, ParquetSchemaType] = dataType match { + case m: MessageType => // Here we don't flatten the fields in the nested schema but just look up through // root fields. Currently, accessing to nested fields does not push down filters // and it does not support to create filters for them. - fields.map(f => f.name -> f.dataType).toMap - case _ => Map.empty[String, DataType] + m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f => + f.getName -> ParquetSchemaType( + f.getOriginalType, f.getPrimitiveTypeName, f.getDecimalMetadata) + }.toMap + case _ => Map.empty[String, ParquetSchemaType] } /** * Converts data sources filters to Parquet filter predicates. */ - def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { + def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = { val nameToType = getFieldMap(schema) // Parquet does not allow dots in the column name because dots are used as a column path @@ -208,6 +289,16 @@ private[parquet] object ParquetFilters { // See SPARK-20364. def canMakeFilterOn(name: String): Boolean = nameToType.contains(name) && !name.contains(".") + // All DataTypes that support `makeEq` can provide better performance. + def shouldConvertInPredicate(name: String): Boolean = nameToType(name) match { + case ParquetBooleanType | ParquetByteType | ParquetShortType | ParquetIntegerType + | ParquetLongType | ParquetFloatType | ParquetDoubleType | ParquetStringType + | ParquetBinaryType => true + case ParquetDateType if pushDownDate => true + case ParquetTimestampMicrosType | ParquetTimestampMillisType if pushDownTimestamp => true + case _ => false + } + // NOTE: // // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, @@ -271,6 +362,43 @@ private[parquet] object ParquetFilters { case sources.Not(pred) => createFilter(schema, pred).map(FilterApi.not) + case sources.In(name, values) if canMakeFilterOn(name) && shouldConvertInPredicate(name) + && values.distinct.length <= pushDownInFilterThreshold => + values.distinct.flatMap { v => + makeEq.lift(nameToType(name)).map(_(name, v)) + }.reduceLeftOption(FilterApi.or) + + case sources.StringStartsWith(name, prefix) if pushDownStartWith && canMakeFilterOn(name) => + Option(prefix).map { v => + FilterApi.userDefined(binaryColumn(name), + new UserDefinedPredicate[Binary] with Serializable { + private val strToBinary = Binary.fromReusedByteArray(v.getBytes) + private val size = strToBinary.length + + override def canDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 + } + + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 + } + + override def keep(value: Binary): Boolean = { + UTF8String.fromBytes(value.getBytes).startsWith( + UTF8String.fromBytes(strToBinary.getBytes)) + } + } + ) + } + case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index f36a89a4c3c5f..9cfc30725f03a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -81,7 +81,10 @@ object ParquetOptions { "uncompressed" -> CompressionCodecName.UNCOMPRESSED, "snappy" -> CompressionCodecName.SNAPPY, "gzip" -> CompressionCodecName.GZIP, - "lzo" -> CompressionCodecName.LZO) + "lzo" -> CompressionCodecName.LZO, + "lz4" -> CompressionCodecName.LZ4, + "brotli" -> CompressionCodecName.BROTLI, + "zstd" -> CompressionCodecName.ZSTD) def getParquetCompressionCodecName(name: String): String = { shortParquetCompressionCodecNames(name).name() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 0dea767840ed3..cab00251622b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -61,7 +61,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { case _: ClassNotFoundException => u case e: Exception => // the provider is valid, but failed to create a logical plan - u.failAnalysis(e.getMessage) + u.failAnalysis(e.getMessage, e) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index e93908da43535..8661a5395ac44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{DataType, StringType, StructType} import org.apache.spark.util.SerializableConfiguration /** @@ -47,11 +47,6 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { throw new AnalysisException( s"Text data source supports only a single column, and you have ${schema.size} columns.") } - val tpe = schema(0).dataType - if (tpe != StringType) { - throw new AnalysisException( - s"Text data source supports only a string column, but you have ${tpe.simpleString}.") - } } override def isSplitable( @@ -141,6 +136,9 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = + dataType == StringType } class TextOutputWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 5c1a35434f7b5..e4e201995faa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.text -import java.nio.charset.StandardCharsets +import java.nio.charset.{Charset, StandardCharsets} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} @@ -41,13 +41,18 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti */ val wholeText = parameters.getOrElse(WHOLETEXT, "false").toBoolean - private val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { sep => - require(sep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") - sep + val encoding: Option[String] = parameters.get(ENCODING) + + val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { lineSep => + require(lineSep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") + + lineSep } + // Note that the option 'lineSep' uses a different default value in read and write. - val lineSeparatorInRead: Option[Array[Byte]] = - lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(encoding.map(Charset.forName(_)).getOrElse(StandardCharsets.UTF_8)) + } val lineSeparatorInWrite: Array[Byte] = lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } @@ -55,5 +60,6 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti private[datasources] object TextOptions { val COMPRESSION = "compression" val WHOLETEXT = "wholetext" + val ENCODING = "encoding" val LINE_SEPARATOR = "lineSep" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala index 017a6737161a6..33079d5912506 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -30,8 +30,8 @@ class DataSourcePartitioning( override val numPartitions: Int = partitioning.numPartitions() - override def satisfies(required: physical.Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: physical.Distribution): Boolean = { + super.satisfies0(required) || { required match { case d: physical.ClusteredDistribution if isCandidate(d.clustering) => val attrs = d.clustering.map(_.asInstanceOf[Attribute]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index f85971be394b1..8d6fb3820d420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -22,24 +22,25 @@ import scala.reflect.ClassTag import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.InputPartition -class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T]) +class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: InputPartition[T]) extends Partition with Serializable class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readerFactories: Seq[DataReaderFactory[T]]) + @transient private val inputPartitions: Seq[InputPartition[T]]) extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { - case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) + inputPartitions.zipWithIndex.map { + case (inputPartition, index) => new DataSourceRDDPartition(index, inputPartition) }.toArray } override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader() + val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition + .createPartitionReader() context.addTaskCompletionListener(_ => reader.close()) val iter = new Iterator[T] { private[this] var valuePrepared = false @@ -63,6 +64,6 @@ class DataSourceRDD[T: ClassTag]( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 2b282ffae2390..7613eb210c659 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -22,74 +22,36 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics} import org.apache.spark.sql.types.StructType +/** + * A logical plan representing a data source v2 scan. + * + * @param source An instance of a [[DataSourceV2]] implementation. + * @param options The options for this scan. Used to create fresh [[DataSourceReader]]. + * @param userSpecifiedSchema The user-specified schema for this scan. Used to create fresh + * [[DataSourceReader]]. + */ case class DataSourceV2Relation( source: DataSourceV2, + output: Seq[AttributeReference], options: Map[String, String], - projection: Seq[AttributeReference], - filters: Option[Seq[Expression]] = None, - userSpecifiedSchema: Option[StructType] = None) + userSpecifiedSchema: Option[StructType]) extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ - override def simpleString: String = "RelationV2 " + metadataString + override def pushedFilters: Seq[Expression] = Seq.empty - override lazy val schema: StructType = reader.readSchema() - - override lazy val output: Seq[AttributeReference] = { - // use the projection attributes to avoid assigning new ids. fields that are not projected - // will be assigned new ids, which is okay because they are not projected. - val attrMap = projection.map(a => a.name -> a).toMap - schema.map(f => attrMap.getOrElse(f.name, - AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())) - } - - private lazy val v2Options: DataSourceOptions = makeV2Options(options) - - lazy val ( - reader: DataSourceReader, - unsupportedFilters: Seq[Expression], - pushedFilters: Seq[Expression]) = { - val newReader = userSpecifiedSchema match { - case Some(s) => - source.asReadSupportWithSchema.createReader(s, v2Options) - case _ => - source.asReadSupport.createReader(v2Options) - } - - DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) - - val (remainingFilters, pushedFilters) = filters match { - case Some(filterSeq) => - DataSourceV2Relation.pushFilters(newReader, filterSeq) - case _ => - (Nil, Nil) - } - - (newReader, remainingFilters, pushedFilters) - } - - override def doCanonicalize(): LogicalPlan = { - val c = super.doCanonicalize().asInstanceOf[DataSourceV2Relation] + override def simpleString: String = "RelationV2 " + metadataString - // override output with canonicalized output to avoid attempting to configure a reader - val canonicalOutput: Seq[AttributeReference] = this.output - .map(a => QueryPlan.normalizeExprId(a, projection)) + def newReader(): DataSourceReader = source.createReader(options, userSpecifiedSchema) - new DataSourceV2Relation(c.source, c.options, c.projection) { - override lazy val output: Seq[AttributeReference] = canonicalOutput - } - } - - override def computeStats(): Statistics = reader match { + override def computeStats(): Statistics = newReader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => @@ -97,9 +59,7 @@ case class DataSourceV2Relation( } override def newInstance(): DataSourceV2Relation = { - // projection is used to maintain id assignment. - // if projection is not set, use output so the copy is not equal to the original - copy(projection = projection.map(_.newInstance())) + copy(output = output.map(_.newInstance())) } } @@ -121,6 +81,8 @@ case class StreamingDataSourceV2Relation( override def simpleString: String = "Streaming RelationV2 " + metadataString + override def pushedFilters: Seq[Expression] = Nil + override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) // TODO: unify the equal/hashCode implementation for all data source v2 query plans. @@ -177,73 +139,26 @@ object DataSourceV2Relation { source.getClass.getSimpleName } } - } - - private def makeV2Options(options: Map[String, String]): DataSourceOptions = { - new DataSourceOptions(options.asJava) - } - private def schema( - source: DataSourceV2, - v2Options: DataSourceOptions, - userSchema: Option[StructType]): StructType = { - val reader = userSchema match { - case Some(s) => - source.asReadSupportWithSchema.createReader(s, v2Options) - case _ => - source.asReadSupport.createReader(v2Options) + def createReader( + options: Map[String, String], + userSpecifiedSchema: Option[StructType]): DataSourceReader = { + val v2Options = new DataSourceOptions(options.asJava) + userSpecifiedSchema match { + case Some(s) => + asReadSupportWithSchema.createReader(s, v2Options) + case _ => + asReadSupport.createReader(v2Options) + } } - reader.readSchema() } def create( source: DataSourceV2, options: Map[String, String], - filters: Option[Seq[Expression]] = None, - userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { - val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes - DataSourceV2Relation(source, options, projection, filters, userSpecifiedSchema) - } - - private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { - reader match { - case projectionSupport: SupportsPushDownRequiredColumns => - projectionSupport.pruneColumns(struct) - case _ => - } - } - - private def pushFilters( - reader: DataSourceReader, - filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - reader match { - case catalystFilterSupport: SupportsPushDownCatalystFilters => - ( - catalystFilterSupport.pushCatalystFilters(filters.toArray), - catalystFilterSupport.pushedCatalystFilters() - ) - - case filterSupport: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, Filter] = filters.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = filterSupport.pushFilters(translatedMap.values.toArray).toSet - val (unhandledPredicates, pushedPredicates) = translatedMap.partition { case (_, f) => - unhandledFilters.contains(f) - } - - (nonConvertiblePredicates ++ unhandledPredicates.keys, pushedPredicates.keys.toSeq) - - case _ => (filters, Nil) - } + userSpecifiedSchema: Option[StructType]): DataSourceV2Relation = { + val reader = source.createReader(options, userSpecifiedSchema) + DataSourceV2Relation( + source, reader.readSchema().toAttributes, options, userSpecifiedSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 3a5e7bf89e142..c6a7684bf6ab0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -41,6 +41,7 @@ case class DataSourceV2ScanExec( output: Seq[AttributeReference], @transient source: DataSourceV2, @transient options: Map[String, String], + @transient pushedFilters: Seq[Expression], @transient reader: DataSourceReader) extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { @@ -58,13 +59,13 @@ case class DataSourceV2ScanExec( } override def outputPartitioning: physical.Partitioning = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 => + case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchPartitions.size == 1 => SinglePartition - case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 => + case r: SupportsScanColumnarBatch if !r.enableBatchRead() && partitions.size == 1 => SinglePartition - case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 => + case r if !r.isInstanceOf[SupportsScanColumnarBatch] && partitions.size == 1 => SinglePartition case s: SupportsReportPartitioning => @@ -74,19 +75,19 @@ case class DataSourceV2ScanExec( case _ => super.outputPartitioning } - private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala + private lazy val partitions: Seq[InputPartition[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.planUnsafeInputPartitions().asScala case _ => - reader.createDataReaderFactories().asScala.map { - new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow] + reader.planInputPartitions().asScala.map { + new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[UnsafeRow] } } - private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match { + private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match { case r: SupportsScanColumnarBatch if r.enableBatchRead() => assert(!reader.isInstanceOf[ContinuousReader], "continuous stream reader does not support columnar read yet.") - r.createBatchDataReaderFactories().asScala + r.planBatchInputPartitions().asScala } private lazy val inputRDD: RDD[InternalRow] = reader match { @@ -94,15 +95,18 @@ case class DataSourceV2ScanExec( EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readerFactories.size)) - new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories) - .asInstanceOf[RDD[InternalRow]] + .askSync[Unit](SetReaderPartitions(partitions.size)) + new ContinuousDataSourceRDD( + sparkContext, + sqlContext.conf.continuousStreamingExecutorQueueSize, + sqlContext.conf.continuousStreamingExecutorPollIntervalMs, + partitions).asInstanceOf[RDD[InternalRow]] case r: SupportsScanColumnarBatch if r.enableBatchRead() => - new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD(sparkContext, batchPartitions).asInstanceOf[RDD[InternalRow]] case _ => - new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD(sparkContext, partitions).asInstanceOf[RDD[InternalRow]] } override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) @@ -127,19 +131,22 @@ case class DataSourceV2ScanExec( } } -class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType) - extends DataReaderFactory[UnsafeRow] { +class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType) + extends InputPartition[UnsafeRow] { - override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations + override def preferredLocations: Array[String] = partition.preferredLocations - override def createDataReader: DataReader[UnsafeRow] = { - new RowToUnsafeDataReader( - rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind()) + override def createPartitionReader: InputPartitionReader[UnsafeRow] = { + new RowToUnsafeInputPartitionReader( + partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind()) } } -class RowToUnsafeDataReader(val rowReader: DataReader[Row], encoder: ExpressionEncoder[Row]) - extends DataReader[UnsafeRow] { +class RowToUnsafeInputPartitionReader( + val rowReader: InputPartitionReader[Row], + encoder: ExpressionEncoder[Row]) + + extends InputPartitionReader[UnsafeRow] { override def next: Boolean = rowReader.next diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1ac9572de6412..2a7f1de2c7c19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,21 +17,142 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.Strategy -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan +import scala.collection.mutable + +import org.apache.spark.sql.{sources, Strategy} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader object DataSourceV2Strategy extends Strategy { + + /** + * Pushes down filters to the data source reader + * + * @return pushed filter and post-scan filters. + */ + private def pushFilters( + reader: DataSourceReader, + filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + reader match { + case r: SupportsPushDownCatalystFilters => + val postScanFilters = r.pushCatalystFilters(filters.toArray) + val pushedFilters = r.pushedCatalystFilters() + (pushedFilters, postScanFilters) + + case r: SupportsPushDownFilters => + // A map from translated data source filters to original catalyst filter expressions. + val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression] + // Catalyst filter expression that can't be translated to data source filters. + val untranslatableExprs = mutable.ArrayBuffer.empty[Expression] + + for (filterExpr <- filters) { + val translated = DataSourceStrategy.translateFilter(filterExpr) + if (translated.isDefined) { + translatedFilterToExpr(translated.get) = filterExpr + } else { + untranslatableExprs += filterExpr + } + } + + // Data source filters that need to be evaluated again after scanning. which means + // the data source cannot guarantee the rows returned can pass these filters. + // As a result we must return it so Spark can plan an extra filter operator. + val postScanFilters = r.pushFilters(translatedFilterToExpr.keys.toArray) + .map(translatedFilterToExpr) + // The filters which are marked as pushed to this data source + val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) + (pushedFilters, untranslatableExprs ++ postScanFilters) + + case _ => (Nil, filters) + } + } + + /** + * Applies column pruning to the data source, w.r.t. the references of the given expressions. + * + * @return new output attributes after column pruning. + */ + // TODO: nested column pruning. + private def pruneColumns( + reader: DataSourceReader, + relation: DataSourceV2Relation, + exprs: Seq[Expression]): Seq[AttributeReference] = { + reader match { + case r: SupportsPushDownRequiredColumns => + val requiredColumns = AttributeSet(exprs.flatMap(_.references)) + val neededOutput = relation.output.filter(requiredColumns.contains) + if (neededOutput != relation.output) { + r.pruneColumns(neededOutput.toStructType) + val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap + r.readSchema().toAttributes.map { + // We have to keep the attribute id during transformation. + a => a.withExprId(nameToAttr(a.name).exprId) + } + } else { + relation.output + } + + case _ => relation.output + } + } + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: DataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil + case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => + val reader = relation.newReader() + // `pushedFilters` will be pushed down and evaluated in the underlying data sources. + // `postScanFilters` need to be evaluated after the scan. + // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. + val (pushedFilters, postScanFilters) = pushFilters(reader, filters) + val output = pruneColumns(reader, relation, project ++ postScanFilters) + logInfo( + s""" + |Pushing operators to ${relation.source.getClass} + |Pushed Filters: ${pushedFilters.mkString(", ")} + |Post-Scan Filters: ${postScanFilters.mkString(",")} + |Output: ${output.mkString(", ")} + """.stripMargin) + + val scan = DataSourceV2ScanExec( + output, relation.source, relation.options, pushedFilters, reader) + + val filterCondition = postScanFilters.reduceLeftOption(And) + val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) + + val withProjection = if (withFilter.output != project) { + ProjectExec(project, withFilter) + } else { + withFilter + } + + withProjection :: Nil case r: StreamingDataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil + DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case WriteToContinuousDataSource(writer, query) => + WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil + + case Repartition(1, false, child) => + val isContinuous = child.collectFirst { + case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r + }.isDefined + + if (isContinuous) { + ContinuousCoalesceExec(1, planLater(child)) :: Nil + } else { + Nil + } + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala index aed55a429bfd7..97e6c6d702acb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -19,11 +19,9 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.commons.lang3.StringUtils -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.util.Utils /** @@ -49,27 +47,22 @@ trait DataSourceV2StringFormat { def options: Map[String, String] /** - * The created data source reader. Here we use it to get the filters that has been pushed down - * so far, itself doesn't take part in the equals/hashCode. + * The filters which have been pushed to the data source. */ - def reader: DataSourceReader - - private lazy val filters = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Set.empty - } + def pushedFilters: Seq[Expression] private def sourceName: String = source match { case registered: DataSourceRegister => registered.shortName() - case _ => source.getClass.getSimpleName.stripSuffix("$") + // source.getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead + case _ => Utils.getSimpleName(source.getClass) } def metadataString: String = { val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] - if (filters.nonEmpty) { - entries += "Filters" -> filters.mkString("[", ", ", "]") + if (pushedFilters.nonEmpty) { + entries += "Filters" -> pushedFilters.mkString("[", ", ", "]") } // TODO: we should only display some standard options like path, table, etc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala deleted file mode 100644 index f23d228567241..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} -import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.Rule - -object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { - override def apply( - plan: LogicalPlan): LogicalPlan = plan transformUp { - // PhysicalOperation guarantees that filters are deterministic; no need to check - case PhysicalOperation(project, newFilters, relation : DataSourceV2Relation) => - // merge the filters - val filters = relation.filters match { - case Some(existing) => - existing ++ newFilters - case _ => - newFilters - } - - val projectAttrs = project.map(_.toAttribute) - val projectSet = AttributeSet(project.flatMap(_.references)) - val filterSet = AttributeSet(filters.flatMap(_.references)) - - val projection = if (filterSet.subsetOf(projectSet) && - AttributeSet(projectAttrs) == projectSet) { - // When the required projection contains all of the filter columns and column pruning alone - // can produce the required projection, push the required projection. - // A final projection may still be needed if the data source produces a different column - // order or if it cannot prune all of the nested columns. - projectAttrs - } else { - // When there are filter columns not already in the required projection or when the required - // projection is more complicated than column pruning, base column pruning on the set of - // all columns needed by both. - (projectSet ++ filterSet).toSeq - } - - val newRelation = relation.copy( - projection = projection.asInstanceOf[Seq[AttributeReference]], - filters = Some(filters)) - - // Add a Filter for any filters that could not be pushed - val unpushedFilter = newRelation.unsupportedFilters.reduceLeftOption(And) - val filtered = unpushedFilter.map(Filter(_, newRelation)).getOrElse(newRelation) - - // Add a Project to ensure the output matches the required projection - if (newRelation.output != projectAttrs) { - Project(project, filtered) - } else { - filtered - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index e80b44c1cdc66..b1148c0f62f7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -29,10 +29,8 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution} -import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} +import org.apache.spark.sql.execution.streaming.MicroBatchExecution import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -65,25 +63,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e s"The input RDD has ${messages.length} partitions.") try { - val runTask = writer match { - // This case means that we're doing continuous processing. In microbatch streaming, the - // StreamWriter is wrapped in a MicroBatchWriter, which is executed as a normal batch. - case w: StreamWriter => - EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - sparkContext.env) - .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) - - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.runContinuous(writeTask, context, iter) - case _ => - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator) - } - sparkContext.runJob( rdd, - runTask, + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator), rdd.partitions.indices, (index, message: WriterCommitMessage) => { messages(index) = message @@ -91,14 +74,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e } ) - if (!writer.isInstanceOf[StreamWriter]) { - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") - } + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") } catch { - case _: InterruptedException if writer.isInstanceOf[StreamWriter] => - // Interruption is how continuous queries are ended, so accept and ignore the exception. case cause: Throwable => logError(s"Data source writer $writer is aborting.") try { @@ -111,8 +90,6 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e } logError(s"Data source writer $writer aborted.") cause match { - // Do not wrap interruption exceptions that will be handled by streaming specially. - case _ if StreamExecution.isInterruptionException(cause) => throw cause // Only wrap non fatal exceptions. case NonFatal(e) => throw new SparkException("Writing job aborted.", e) case _ => throw cause @@ -130,23 +107,29 @@ object DataWritingSparkTask extends Logging { iter: Iterator[InternalRow], useCommitCoordinator: Boolean): WriterCommitMessage = { val stageId = context.stageId() + val stageAttempt = context.stageAttemptNumber() val partId = context.partitionId() + val taskId = context.taskAttemptId() val attemptId = context.attemptNumber() val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") - val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong) + val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - iter.foreach(dataWriter.write) + while (iter.hasNext) { + dataWriter.write(iter.next()) + } val msg = if (useCommitCoordinator) { val coordinator = SparkEnv.get.outputCommitCoordinator - val commitAuthorized = coordinator.canCommit(context.stageId(), partId, attemptId) + val commitAuthorized = coordinator.canCommit(stageId, stageAttempt, partId, attemptId) if (commitAuthorized) { - logInfo(s"Writer for stage $stageId, task $partId.$attemptId is authorized to commit.") + logInfo(s"Commit authorized for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") dataWriter.commit() } else { - val message = s"Stage $stageId, task $partId.$attemptId: driver did not authorize commit" + val message = s"Commit denied for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)" logInfo(message) // throwing CommitDeniedException will trigger the catch block for abort throw new CommitDeniedException(message, stageId, partId, attemptId) @@ -157,60 +140,20 @@ object DataWritingSparkTask extends Logging { dataWriter.commit() } - logInfo(s"Writer for stage $stageId, task $partId.$attemptId committed.") + logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") msg })(catchBlock = { // If there is an error, abort this writer - logError(s"Writer for stage $stageId, task $partId.$attemptId is aborting.") + logError(s"Aborting commit for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") dataWriter.abort() - logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.") + logError(s"Aborted commit for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") }) } - - def runContinuous( - writeTask: DataWriterFactory[InternalRow], - context: TaskContext, - iter: Iterator[InternalRow]): WriterCommitMessage = { - val epochCoordinator = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - SparkEnv.get) - val currentMsg: WriterCommitMessage = null - var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - do { - var dataWriter: DataWriter[InternalRow] = null - // write the data and commit this writer. - Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - try { - dataWriter = writeTask.createDataWriter( - context.partitionId(), context.attemptNumber(), currentEpoch) - while (iter.hasNext) { - dataWriter.write(iter.next()) - } - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") - epochCoordinator.send( - CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) - ) - currentEpoch += 1 - } catch { - case _: InterruptedException => - // Continuous shutdown always involves an interrupt. Just finish the task. - } - })(catchBlock = { - // If there is an error, abort this writer. We enter this callback in the middle of - // rethrowing an exception, so runContinuous will stop executing at this point. - logError(s"Writer for partition ${context.partitionId()} is aborting.") - if (dataWriter != null) dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") - }) - } while (!context.isInterrupted()) - - currentMsg - } } class InternalRowDataWriterFactory( @@ -219,10 +162,10 @@ class InternalRowDataWriterFactory( override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { new InternalRowDataWriter( - rowWriterFactory.createDataWriter(partitionId, attemptNumber, epochId), + rowWriterFactory.createDataWriter(partitionId, taskId, epochId), RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index daea6c39624d6..a80673c705f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.execution.exchange +import java.util.concurrent.TimeoutException + import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ +import scala.util.control.NonFatal import org.apache.spark.{broadcast, SparkException} import org.apache.spark.launcher.SparkLauncher @@ -30,7 +33,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.joins.HashedRelation import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{SparkFatalException, ThreadUtils} /** * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of @@ -69,7 +72,7 @@ case class BroadcastExchangeExec( Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { try { val beforeCollect = System.nanoTime() // Use executeCollect/executeCollectIterator to avoid conversion to Scala types @@ -111,12 +114,18 @@ case class BroadcastExchangeExec( SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) broadcasted } catch { + // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw + // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult + // will catch this exception and re-throw the wrapped fatal throwable. case oe: OutOfMemoryError => - throw new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " + + throw new SparkFatalException( + new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " + s"all worker nodes. As a workaround, you can either disable broadcast by setting " + s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark driver " + s"memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value") - .initCause(oe.getCause) + .initCause(oe.getCause)) + case e if !NonFatal(e) => + throw new SparkFatalException(e) } } }(BroadcastExchangeExec.executionContext) @@ -133,7 +142,16 @@ case class BroadcastExchangeExec( } override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + try { + ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + } catch { + case ex: TimeoutException => + logError(s"Could not execute broadcast in ${timeout.toSeconds} secs.", ex) + throw new SparkException(s"Could not execute broadcast in ${timeout.toSeconds} secs. " + + s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " + + s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1", + ex) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index e3d28388c5470..d96ecbaa48029 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.exchange +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ @@ -227,9 +228,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { val leftKeysBuffer = ArrayBuffer[Expression]() val rightKeysBuffer = ArrayBuffer[Expression]() + val pickedIndexes = mutable.Set[Int]() + val keysAndIndexes = currentOrderOfKeys.zipWithIndex expectedOrderOfKeys.foreach(expression => { - val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) + val index = keysAndIndexes.find { case (e, idx) => + // As we may have the same key used many times, we need to filter out its occurrence we + // have already used. + e.semanticEquals(expression) && !pickedIndexes.contains(idx) + }.map(_._2).get + pickedIndexes += index leftKeysBuffer.append(leftKeys(index)) rightKeysBuffer.append(rightKeys(index)) }) @@ -270,14 +278,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { * partitioning of the join nodes' children. */ private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { - plan.transformUp { - case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, - right) => - val (reorderedLeftKeys, reorderedRightKeys) = - reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) - BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, - left, right) - + plan match { case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) @@ -288,6 +289,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) + + case other => other } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 09f79a2de0ba0..1a5b7599bb7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -24,7 +24,7 @@ import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder} -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf @@ -70,7 +70,7 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan } override def outputPartitioning: Partitioning = child.outputPartitioning match { - case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr)) + case e: Expression => updateAttr(e).asInstanceOf[Partitioning] case other => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index 78f11ca8d8c78..051e610eb2705 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -232,16 +232,16 @@ class ExchangeCoordinator( // number of post-shuffle partitions. val partitionStartIndices = if (mapOutputStatistics.length == 0) { - None + Array.empty[Int] } else { - Some(estimatePartitionStartIndices(mapOutputStatistics)) + estimatePartitionStartIndices(mapOutputStatistics) } var k = 0 while (k < numExchanges) { val exchange = exchanges(k) val rdd = - exchange.preparePostShuffleRDD(shuffleDependencies(k), partitionStartIndices) + exchange.preparePostShuffleRDD(shuffleDependencies(k), Some(partitionStartIndices)) newPostShuffleRDDs.put(exchange, rdd) k += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 6fa716d9fadee..0da0e8610c392 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} @@ -183,7 +184,7 @@ case class BroadcastHashJoinExec( val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val javaType = CodeGenerator.javaType(a.dataType) - val code = s""" + val code = code""" |boolean $isNull = true; |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; |if ($matched != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 1465346eb802d..20ce01f4ce8cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -557,7 +557,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap def append(key: Long, row: UnsafeRow): Unit = { val sizeInBytes = row.getSizeInBytes if (sizeInBytes >= (1 << SIZE_BITS)) { - sys.error("Does not support row that is larger than 256M") + throw new UnsupportedOperationException("Does not support row that is larger than 256M") } if (key < minKey) { @@ -567,19 +567,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap maxKey = key } - // There is 8 bytes for the pointer to next value - if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) { - val used = page.length - if (used >= (1 << 30)) { - sys.error("Can not build a HashedRelation that is larger than 8G") - } - ensureAcquireMemory(used * 8L * 2) - val newPage = new Array[Long](used * 2) - Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET, - cursor - Platform.LONG_ARRAY_OFFSET) - page = newPage - freeMemory(used * 8L) - } + grow(row.getSizeInBytes) // copy the bytes of UnsafeRow val offset = cursor @@ -615,7 +603,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap growArray() } else if (numKeys > array.length / 2 * 0.75) { // The fill ratio should be less than 0.75 - sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys") + throw new UnsupportedOperationException( + "Cannot build HashedRelation with more than 1/3 billions unique keys") } } } else { @@ -626,6 +615,25 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } + private def grow(inputRowSize: Int): Unit = { + // There is 8 bytes for the pointer to next value + val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8 + if (neededNumWords > page.length) { + if (neededNumWords > (1 << 30)) { + throw new UnsupportedOperationException( + "Can not build a HashedRelation that is larger than 8G") + } + val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30)) + ensureAcquireMemory(newNumWords * 8L) + val newPage = new Array[Long](newNumWords.toInt) + Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET, + cursor - Platform.LONG_ARRAY_OFFSET) + val used = page.length + page = newPage + freeMemory(used * 8L) + } + } + private def growArray(): Unit = { var old_array = array val n = array.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index d8261f0f33b61..f4b9d132122e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ @@ -521,7 +522,7 @@ case class SortMergeJoinExec( if (a.nullable) { val isNull = ctx.freshName("isNull") val code = - s""" + code""" |$isNull = $leftRow.isNullAt($i); |$value = $isNull ? $defaultValue : ($valueCode); """.stripMargin @@ -533,7 +534,7 @@ case class SortMergeJoinExec( (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)), leftVarsDecl) } else { - val code = s"$value = $valueCode;" + val code = code"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 77b907870d678..b4f0ae1eb1a18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.metric import java.text.NumberFormat import java.util.Locale +import java.util.concurrent.atomic.LongAdder import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo @@ -32,40 +33,45 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]]. */ class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] { + // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will // update it at the end of task and the value will be at least 0. Then we can filter out the -1 // values before calculate max, min, etc. - private[this] var _value = initValue - private var _zeroValue = initValue + private[this] val _value = new LongAdder + private val _zeroValue = initValue + _value.add(initValue) override def copy(): SQLMetric = { - val newAcc = new SQLMetric(metricType, _value) - newAcc._zeroValue = initValue + val newAcc = new SQLMetric(metricType, initValue) + newAcc.add(_value.sum()) newAcc } - override def reset(): Unit = _value = _zeroValue + override def reset(): Unit = this.set(_zeroValue) override def merge(other: AccumulatorV2[Long, Long]): Unit = other match { - case o: SQLMetric => _value += o.value + case o: SQLMetric => _value.add(o.value) case _ => throw new UnsupportedOperationException( s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - override def isZero(): Boolean = _value == _zeroValue + override def isZero(): Boolean = _value.sum() == _zeroValue - override def add(v: Long): Unit = _value += v + override def add(v: Long): Unit = _value.add(v) // We can set a double value to `SQLMetric` which stores only long value, if it is // average metrics. def set(v: Double): Unit = SQLMetrics.setDoubleForAverageMetrics(this, v) - def set(v: Long): Unit = _value = v + def set(v: Long): Unit = { + _value.reset() + _value.add(v) + } - def +=(v: Long): Unit = _value += v + def +=(v: Long): Unit = _value.add(v) - override def value: Long = _value + override def value: Long = _value.sum() // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { @@ -153,7 +159,7 @@ object SQLMetrics { Seq.fill(3)(0L) } else { val sorted = validValues.sorted - Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + Seq(sorted.head, sorted(validValues.length / 2), sorted(validValues.length - 1)) } metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric)) } @@ -173,7 +179,8 @@ object SQLMetrics { Seq.fill(4)(0L) } else { val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + Seq(sorted.sum, sorted.head, sorted(validValues.length / 2), + sorted(validValues.length - 1)) } metric.map(strFormat) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 8e01e8e56a5bd..d00f6f042d6e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -81,7 +82,7 @@ case class AggregateInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip @@ -135,10 +136,14 @@ case class AggregateInPandasExec( } val columnarBatchIter = new ArrowPythonRunner( - pyFuncs, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, argOffsets, aggInputSchema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(projectedRowIter, context.partitionId(), context) + pyFuncs, + bufferSize, + reuseWorker, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + argOffsets, + aggInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context) val joinedAttributes = groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index c4de214679ae4..0bc21c0986e69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -24,6 +24,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType /** @@ -63,7 +64,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi private val batchSize = conf.arrowMaxRecordsPerBatch private val sessionLocalTimeZone = conf.sessionLocalTimeZone - private val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) protected override def evaluate( funcs: Seq[ChainedPythonFunctions], @@ -80,10 +81,14 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) val columnarBatchIter = new ArrowPythonRunner( - funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(batchIter, context.partitionId(), context) + funcs, + bufferSize, + reuseWorker, + PythonEvalType.SQL_SCALAR_PANDAS_UDF, + argOffsets, + schema, + sessionLocalTimeZone, + pythonRunnerConf).compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 5fcdcddca7d51..ca665652f204d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -45,7 +45,7 @@ class ArrowPythonRunner( argOffsets: Array[Array[Int]], schema: StructType, timeZoneId: String, - respectTimeZone: Boolean) + conf: Map[String, String]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { @@ -58,31 +58,28 @@ class ArrowPythonRunner( new WriterThread(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - if (respectTimeZone) { - PythonRDD.writeUTF(timeZoneId, dataOut) - } else { - dataOut.writeInt(SpecialLengths.NULL) + + // Write config for the worker as a number of key -> value pairs of strings + dataOut.writeInt(conf.size) + for ((k, v) <- conf) { + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) } + + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec", 0, Long.MaxValue) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - val arrowWriter = ArrowWriter.create(root) - - context.addTaskCompletionListener { _ => - root.close() - allocator.close() - } - - val writer = new ArrowStreamWriter(root, null, dataOut) - writer.start() Utils.tryWithSafeFinally { + val arrowWriter = ArrowWriter.create(root) + val writer = new ArrowStreamWriter(root, null, dataOut) + writer.start() + while (inputIterator.hasNext) { val nextBatch = inputIterator.next() @@ -94,8 +91,21 @@ class ArrowPythonRunner( writer.writeBatch() arrowWriter.reset() } - } { + // end writes footer to the output stream and doesn't clean any resources. + // It could throw exception if the output stream is closed, so it should be + // in the try block. writer.end() + } { + // If we close root and allocator in TaskCompletionListener, there could be a race + // condition where the writer thread keeps writing to the VectorSchemaRoot while + // it's being closed by the TaskCompletion listener. + // Closing root and allocator here is cleaner because root and allocator is owned + // by the writer thread and is only visible to the writer thread. + // + // If the writer thread is interrupted by TaskCompletionListener, it should either + // (1) in the try block, in which case it will get an InterruptedException when + // performing io, and goes into the finally block or (2) in the finally block, + // in which case it will ignore the interruption and close the resources. root.close() allocator.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 9d56f48249982..1e096100f7f43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -39,7 +39,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ private def belongAggregate(e: Expression, agg: Aggregate): Boolean = { e.isInstanceOf[AggregateExpression] || - PythonUDF.isGroupAggPandasUDF(e) || + PythonUDF.isGroupedAggPandasUDF(e) || agg.groupingExpressions.exists(_.semanticEquals(e)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 513e174c7733e..f5a563baf52df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType /** @@ -77,7 +78,7 @@ case class FlatMapGroupsInPandasExec( val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) // Deduplicate the grouping attributes. // If a grouping attribute also appears in data attributes, then we don't need to send the @@ -139,10 +140,14 @@ case class FlatMapGroupsInPandasExec( val context = TaskContext.get() val columnarBatchIter = new ArrowPythonRunner( - chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(grouped, context.partitionId(), context) + chainedFunc, + bufferSize, + reuseWorker, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + argOffsets, + dedupSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala new file mode 100644 index 0000000000000..a58773122922f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python._ +import org.apache.spark.internal.Logging +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{NextIterator, Utils} + +class PythonForeachWriter(func: PythonFunction, schema: StructType) + extends ForeachWriter[UnsafeRow] { + + private lazy val context = TaskContext.get() + private lazy val buffer = new PythonForeachWriter.UnsafeRowBuffer( + context.taskMemoryManager, new File(Utils.getLocalDir(SparkEnv.get.conf)), schema.fields.length) + private lazy val inputRowIterator = buffer.iterator + + private lazy val inputByteIterator = { + EvaluatePython.registerPicklers() + val objIterator = inputRowIterator.map { row => EvaluatePython.toJava(row, schema) } + new SerDeUtil.AutoBatchedPickler(objIterator) + } + + private lazy val pythonRunner = { + val conf = SparkEnv.get.conf + val bufferSize = conf.getInt("spark.buffer.size", 65536) + val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) + PythonRunner(func, bufferSize, reuseWorker) + } + + private lazy val outputIterator = + pythonRunner.compute(inputByteIterator, context.partitionId(), context) + + override def open(partitionId: Long, version: Long): Boolean = { + outputIterator // initialize everything + TaskContext.get.addTaskCompletionListener { _ => buffer.close() } + true + } + + override def process(value: UnsafeRow): Unit = { + buffer.add(value) + } + + override def close(errorOrNull: Throwable): Unit = { + buffer.allRowsAdded() + if (outputIterator.hasNext) outputIterator.next() // to throw python exception if there was one + } +} + +object PythonForeachWriter { + + /** + * A buffer that is designed for the sole purpose of buffering UnsafeRows in PythonForeachWriter. + * It is designed to be used with only 1 writer thread (i.e. JVM task thread) and only 1 reader + * thread (i.e. PythonRunner writing thread that reads from the buffer and writes to the Python + * worker stdin). Adds to the buffer are non-blocking, and reads through the buffer's iterator + * are blocking, that is, it blocks until new data is available or all data has been added. + * + * Internally, it uses a [[HybridRowQueue]] to buffer the rows in a practically unlimited queue + * across memory and local disk. However, HybridRowQueue is designed to be used only with + * EvalPythonExec where the reader is always behind the the writer, that is, the reader does not + * try to read n+1 rows if the writer has only written n rows at any point of time. This + * assumption is not true for PythonForeachWriter where rows may be added at a different rate as + * they are consumed by the python worker. Hence, to maintain the invariant of the reader being + * behind the writer while using HybridRowQueue, the buffer does the following + * - Keeps a count of the rows in the HybridRowQueue + * - Blocks the buffer's consuming iterator when the count is 0 so that the reader does not + * try to read more rows than what has been written. + * + * The implementation of the blocking iterator (ReentrantLock, Condition, etc.) has been borrowed + * from that of ArrayBlockingQueue. + */ + class UnsafeRowBuffer(taskMemoryManager: TaskMemoryManager, tempDir: File, numFields: Int) + extends Logging { + private val queue = HybridRowQueue(taskMemoryManager, tempDir, numFields) + private val lock = new ReentrantLock() + private val unblockRemove = lock.newCondition() + + // All of these are guarded by `lock` + private var count = 0L + private var allAdded = false + private var exception: Throwable = null + + val iterator = new NextIterator[UnsafeRow] { + override protected def getNext(): UnsafeRow = { + val row = remove() + if (row == null) finished = true + row + } + override protected def close(): Unit = { } + } + + def add(row: UnsafeRow): Unit = withLock { + assert(queue.add(row), s"Failed to add row to HybridRowQueue while sending data to Python" + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count += 1 + unblockRemove.signal() + logTrace(s"Added $row, $count left") + } + + private def remove(): UnsafeRow = withLock { + while (count == 0 && !allAdded && exception == null) { + unblockRemove.await(100, TimeUnit.MILLISECONDS) + } + + // If there was any error in the adding thread, then rethrow it in the removing thread + if (exception != null) throw exception + + if (count > 0) { + val row = queue.remove() + assert(row != null, "HybridRowQueue.remove() returned null " + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count -= 1 + logTrace(s"Removed $row, $count left") + row + } else { + null + } + } + + def allRowsAdded(): Unit = withLock { + allAdded = true + unblockRemove.signal() + } + + def close(): Unit = { queue.close() } + + private def withLock[T](f: => T): T = { + lock.lockInterruptibly() + try { f } catch { + case e: Throwable => + if (exception == null) exception = e + throw e + } finally { lock.unlock() } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala new file mode 100644 index 0000000000000..628029b13a6c3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils +import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.util.Utils + +case class WindowInPandasExec( + windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + child: SparkPlan) extends UnaryExecNode { + + override def output: Seq[Attribute] = + child.output ++ windowExpression.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = { + if (partitionSpec.isEmpty) { + // Only show warning when the number of bytes is larger than 100 MB? + logWarning("No Partition Defined for Window operation! Moving all data to a single " + + "partition, this can cause serious performance degradation.") + AllTuples :: Nil + } else { + ClusteredDistribution(partitionSpec) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + /** + * Create the resulting projection. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param expressions unbound ordered function expressions. + * @return the final resulting projection. + */ + private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { + val references = expressions.zipWithIndex.map { case (e, i) => + // Results of window expressions will be on the right side of child's output + BoundReference(child.output.size + i, e.dataType, e.nullable) + } + val unboundToRefMap = expressions.zip(references).toMap + val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) + UnsafeProjection.create( + child.output ++ patchedWindowExpression, + child.output) + } + + protected override def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + // Extract window expressions and window functions + val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e }) + + val udfExpressions = expressions.map(_.windowFunction.asInstanceOf[PythonUDF]) + + val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + + // Filter child output attributes down to only those that are UDF inputs. + // Also eliminate duplicate UDF inputs. + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + + // Schema of input rows to the python runner + val windowInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }) + + inputRDD.mapPartitionsInternal { iter => + val context = TaskContext.get() + + val grouped = if (partitionSpec.isEmpty) { + // Use an empty unsafe row as a place holder for the grouping key + Iterator((new UnsafeRow(), iter)) + } else { + GroupedIterator(iter, partitionSpec, child.output) + } + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) + context.addTaskCompletionListener { _ => + queue.close() + } + + val inputProj = UnsafeProjection.create(allInputs, child.output) + val pythonInput = grouped.map { case (_, rows) => + rows.map { row => + queue.add(row.asInstanceOf[UnsafeRow]) + inputProj(row) + } + } + + val windowFunctionResult = new ArrowPythonRunner( + pyFuncs, + bufferSize, + reuseWorker, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, + argOffsets, + windowInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(pythonInput, context.partitionId(), context) + + val joined = new JoinedRow + val resultProj = createResultProjection(expressions) + + windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, windowOutput) + resultProj(joinedRow) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 80769d728b8f1..8e82cccbc8fa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -97,6 +97,18 @@ case class FlatMapGroupsWithStateExec( override def keyExpressions: Seq[Attribute] = groupingAttributes + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + timeoutConf match { + case ProcessingTimeTimeout => + true // Always run batches to process timeouts + case EventTimeTimeout => + // Process another non-data batch only if the watermark has changed in this executed plan + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + case _ => + false + } + } + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver @@ -126,7 +138,6 @@ case class FlatMapGroupsWithStateExec( case _ => iter } - // Generate a iterator that returns the rows grouped by the grouping function // Note that this code ensures that the filtering for timeout occurs only after // all the data has been processed. This is to ensure that the timeout information of all @@ -194,11 +205,11 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutKeys = store.getRange(None, None).filter { rowPair => + val timingOutPairs = store.getRange(None, None).filter { rowPair => val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold } - timingOutKeys.flatMap { rowPair => + timingOutPairs.flatMap { rowPair => callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) } } else Iterator.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 1a83c884d55bd..6ae7f2869b0f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -59,7 +59,8 @@ class IncrementalExecution( StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: StreamingRelationStrategy :: - StreamingDeduplicationStrategy :: Nil + StreamingDeduplicationStrategy :: + StreamingGlobalLimitStrategy(outputMode) :: Nil } private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) @@ -134,8 +135,12 @@ class IncrementalExecution( stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, - Some(offsetSeqMetadata.batchWatermarkMs)) - ) + Some(offsetSeqMetadata.batchWatermarkMs))) + + case l: StreamingGlobalLimitExec => + l.copy( + stateInfo = Some(nextStatefulOperationStateInfo), + outputMode = Some(outputMode)) } } @@ -143,4 +148,14 @@ class IncrementalExecution( /** No need assert supported, as this check has already been done */ override def assertSupported(): Unit = { } + + /** + * Should the MicroBatchExecution run another batch based on this execution and the current + * updated metadata. + */ + def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + executedPlan.collect { + case p: StateStoreWriter => p.shouldRunAnotherBatch(newMetadata) + }.exists(_ == true) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 6e231970f4a22..45c43f549d24f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -61,6 +61,8 @@ class MicroBatchExecution( case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") } + private var watermarkTracker: WatermarkTracker = _ + override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, "logicalPlan must be initialized in QueryExecutionThread " + @@ -124,44 +126,90 @@ class MicroBatchExecution( _logicalPlan } + /** + * Signifies whether current batch (i.e. for the batch `currentBatchId`) has been constructed + * (i.e. written to the offsetLog) and is ready for execution. + */ + private var isCurrentBatchConstructed = false + + /** + * Signals to the thread executing micro-batches that it should stop running after the next + * batch. This method blocks until the thread stops running. + */ + override def stop(): Unit = { + // Set the state to TERMINATED so that the batching thread knows that it was interrupted + // intentionally + state.set(TERMINATED) + if (queryExecutionThread.isAlive) { + sparkSession.sparkContext.cancelJobGroup(runId.toString) + queryExecutionThread.interrupt() + queryExecutionThread.join() + // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak + sparkSession.sparkContext.cancelJobGroup(runId.toString) + } + logInfo(s"Query $prettyIdString was stopped") + } + /** * Repeatedly attempts to run batches as data arrives. */ protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - triggerExecutor.execute(() => { - startTrigger() + val noDataBatchesEnabled = + sparkSessionForStream.sessionState.conf.streamingNoDataMicroBatchesEnabled + + triggerExecutor.execute(() => { if (isActive) { + var currentBatchHasNewData = false // Whether the current batch had new data + + startTrigger() + reportTimeTaken("triggerExecution") { + // We'll do this initialization only once every start / restart if (currentBatchId < 0) { - // We'll do this initialization only once populateStartOffsets(sparkSessionForStream) - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - logDebug(s"Stream running from $committedOffsets to $availableOffsets") - } else { - constructNextBatch() + logInfo(s"Stream started from $committedOffsets") } - if (dataAvailable) { - currentStatus = currentStatus.copy(isDataAvailable = true) - updateStatusMessage("Processing new data") + + // Set this before calling constructNextBatch() so any Spark jobs executed by sources + // while getting new data have the correct description + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) + + // Try to construct the next batch. This will return true only if the next batch is + // ready and runnable. Note that the current batch may be runnable even without + // new data to process as `constructNextBatch` may decide to run a batch for + // state cleanup, etc. `isNewDataAvailable` will be updated to reflect whether new data + // is available or not. + if (!isCurrentBatchConstructed) { + isCurrentBatchConstructed = constructNextBatch(noDataBatchesEnabled) + } + + // Record the trigger offset range for progress reporting *before* processing the batch + recordTriggerOffsets(from = committedOffsets, to = availableOffsets) + + // Remember whether the current batch has data or not. This will be required later + // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed + // to false as the batch would have already processed the available data. + currentBatchHasNewData = isNewDataAvailable + + currentStatus = currentStatus.copy(isDataAvailable = isNewDataAvailable) + if (isCurrentBatchConstructed) { + if (currentBatchHasNewData) updateStatusMessage("Processing new data") + else updateStatusMessage("No new data but cleaning up state") runBatch(sparkSessionForStream) + } else { + updateStatusMessage("Waiting for data to arrive") } } - // Report trigger as finished and construct progress object. - finishTrigger(dataAvailable) - if (dataAvailable) { - // Update committed offsets. - commitLog.add(currentBatchId) - committedOffsets ++= availableOffsets - logDebug(s"batch ${currentBatchId} committed") - // We'll increase currentBatchId after we complete processing current batch's data + + finishTrigger(currentBatchHasNewData) // Must be outside reportTimeTaken so it is recorded + + // If the current batch has been executed, then increment the batch id and reset flag. + // Otherwise, there was no data to execute the batch and sleep for some time + if (isCurrentBatchConstructed) { currentBatchId += 1 - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - } else { - currentStatus = currentStatus.copy(isDataAvailable = false) - updateStatusMessage("Waiting for data to arrive") - Thread.sleep(pollingDelayMs) - } + isCurrentBatchConstructed = false + } else Thread.sleep(pollingDelayMs) } updateStatusMessage("Waiting for next trigger") isActive @@ -196,6 +244,7 @@ class MicroBatchExecution( /* First assume that we are re-executing the latest known batch * in the offset log */ currentBatchId = latestBatchId + isCurrentBatchConstructed = true availableOffsets = nextOffsets.toStreamProgress(sources) /* Initialize committed offsets to a committed batch, which at this * is the second latest batch id in the offset log. */ @@ -211,6 +260,8 @@ class MicroBatchExecution( OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) + watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) + watermarkTracker.setWatermark(metadata.batchWatermarkMs) } /* identify the current batch id: if commit log indicates we successfully processed the @@ -233,9 +284,9 @@ class MicroBatchExecution( // here, so we do nothing here. } currentBatchId = latestCommittedBatchId + 1 + isCurrentBatchConstructed = false committedOffsets ++= availableOffsets // Construct a new batch be recomputing availableOffsets - constructNextBatch() } else if (latestCommittedBatchId < latestBatchId - 1) { logWarning(s"Batch completion log latest batch id is " + s"${latestCommittedBatchId}, which is not trailing " + @@ -243,19 +294,19 @@ class MicroBatchExecution( } case None => logInfo("no commit log present") } - logDebug(s"Resuming at batch $currentBatchId with committed offsets " + + logInfo(s"Resuming at batch $currentBatchId with committed offsets " + s"$committedOffsets and available offsets $availableOffsets") case None => // We are starting this stream for the first time. logInfo(s"Starting new streaming query.") currentBatchId = 0 - constructNextBatch() + watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) } } /** * Returns true if there is any new data available to be processed. */ - private def dataAvailable: Boolean = { + private def isNewDataAvailable: Boolean = { availableOffsets.exists { case (source, available) => committedOffsets @@ -266,93 +317,65 @@ class MicroBatchExecution( } /** - * Queries all of the sources to see if any new data is available. When there is new data the - * batchId counter is incremented and a new log entry is written with the newest offsets. + * Attempts to construct a batch according to: + * - Availability of new data + * - Need for timeouts and state cleanups in stateful operators + * + * Returns true only if the next batch should be executed. + * + * Here is the high-level logic on how this constructs the next batch. + * - Check each source whether new data is available + * - Updated the query's metadata and check using the last execution whether there is any need + * to run another batch (for state clean up, etc.) + * - If either of the above is true, then construct the next batch by committing to the offset + * log that range of offsets that the next batch will process. */ - private def constructNextBatch(): Unit = { - // Check to see what new data is available. - val hasNewData = { - awaitProgressLock.lock() - try { - // Generate a map from each unique source to the next available offset. - val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { - case s: Source => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("getOffset") { - (s, s.getOffset) - } - case s: MicroBatchReader => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("setOffsetRange") { - // Once v1 streaming source execution is gone, we can refactor this away. - // For now, we set the range here to get the source to infer the available end offset, - // get that offset, and then set the range again when we later execute. - s.setOffsetRange( - toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), - Optional.empty()) - } - - val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } - (s, Option(currentOffset)) - }.toMap - availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) - - if (dataAvailable) { - true - } else { - noNewData = true - false - } - } finally { - awaitProgressLock.unlock() - } - } - if (hasNewData) { - var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs - // Update the eventTime watermarks if we find any in the plan. - if (lastExecution != null) { - lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec => e - }.zipWithIndex.foreach { - case (e, index) if e.eventTimeStats.value.count > 0 => - logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") - val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs - val prevWatermarkMs = watermarkMsMap.get(index) - if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { - watermarkMsMap.put(index, newWatermarkMs) - } - - // Populate 0 if we haven't seen any data yet for this watermark node. - case (_, index) => - if (!watermarkMsMap.isDefinedAt(index)) { - watermarkMsMap.put(index, 0) - } + private def constructNextBatch(noDataBatchesEnabled: Boolean): Boolean = withProgressLocked { + if (isCurrentBatchConstructed) return true + + // Generate a map from each unique source to the next available offset. + val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { + case s: Source => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("getOffset") { + (s, s.getOffset) } - - // Update the global watermark to the minimum of all watermark nodes. - // This is the safest option, because only the global watermark is fault-tolerant. Making - // it the minimum of all individual watermarks guarantees it will never advance past where - // any individual watermark operator would be if it were in a plan by itself. - if(!watermarkMsMap.isEmpty) { - val newWatermarkMs = watermarkMsMap.minBy(_._2)._2 - if (newWatermarkMs > batchWatermarkMs) { - logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") - batchWatermarkMs = newWatermarkMs - } else { - logDebug( - s"Event time didn't move: $newWatermarkMs < " + - s"$batchWatermarkMs") - } + case s: MicroBatchReader => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("setOffsetRange") { + // Once v1 streaming source execution is gone, we can refactor this away. + // For now, we set the range here to get the source to infer the available end offset, + // get that offset, and then set the range again when we later execute. + s.setOffsetRange( + toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), + Optional.empty()) } - } - offsetSeqMetadata = offsetSeqMetadata.copy( - batchWatermarkMs = batchWatermarkMs, - batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds + val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } + (s, Option(currentOffset)) + }.toMap + availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) + + // Update the query metadata + offsetSeqMetadata = offsetSeqMetadata.copy( + batchWatermarkMs = watermarkTracker.currentWatermark, + batchTimestampMs = triggerClock.getTimeMillis()) + + // Check whether next batch should be constructed + val lastExecutionRequiresAnotherBatch = noDataBatchesEnabled && + Option(lastExecution).exists(_.shouldRunAnotherBatch(offsetSeqMetadata)) + val shouldConstructNextBatch = isNewDataAvailable || lastExecutionRequiresAnotherBatch + logTrace( + s"noDataBatchesEnabled = $noDataBatchesEnabled, " + + s"lastExecutionRequiresAnotherBatch = $lastExecutionRequiresAnotherBatch, " + + s"isNewDataAvailable = $isNewDataAvailable, " + + s"shouldConstructNextBatch = $shouldConstructNextBatch") + + if (shouldConstructNextBatch) { + // Commit the next batch offset range to the offset log updateStatusMessage("Writing offsets to log") reportTimeTaken("walCommit") { - assert(offsetLog.add( - currentBatchId, + assert(offsetLog.add(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata)), s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") logInfo(s"Committed offsets for batch $currentBatchId. " + @@ -373,7 +396,7 @@ class MicroBatchExecution( reader.commit(reader.deserializeOffset(off.json)) } } else { - throw new IllegalStateException(s"batch $currentBatchId doesn't exist") + throw new IllegalStateException(s"batch ${currentBatchId - 1} doesn't exist") } } @@ -384,15 +407,12 @@ class MicroBatchExecution( commitLog.purge(currentBatchId - minLogEntriesToMaintain) } } + noNewData = false } else { - awaitProgressLock.lock() - try { - // Wake up any threads that are waiting for the stream to progress. - awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() - } + noNewData = true + awaitProgressLockCondition.signalAll() } + shouldConstructNextBatch } /** @@ -400,6 +420,8 @@ class MicroBatchExecution( * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with. */ private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = { + logDebug(s"Running batch $currentBatchId") + // Request unprocessed data from all sources. newData = reportTimeTaken("getBatch") { availableOffsets.flatMap { @@ -513,17 +535,17 @@ class MicroBatchExecution( } } - awaitProgressLock.lock() - try { - // Wake up any threads that are waiting for the stream to progress. + withProgressLocked { + commitLog.add(currentBatchId) + committedOffsets ++= availableOffsets awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() } + watermarkTracker.updateWatermark(lastExecution.executedPlan) + logDebug(s"Completed batch ${currentBatchId}") } /** Execute a function while locking the stream from making an progress */ - private[sql] def withProgressLocked(f: => Unit): Unit = { + private[sql] def withProgressLocked[T](f: => T): T = { awaitProgressLock.lock() try { f diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java index 80aa5505db991..43ad4b3384ec3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java @@ -19,8 +19,8 @@ /** * This is an internal, deprecated interface. New source implementations should use the - * org.apache.spark.sql.sources.v2.reader.Offset class, which is the one that will be supported - * in the long term. + * org.apache.spark.sql.sources.v2.reader.streaming.Offset class, which is the one that will be + * supported in the long term. * * This class will be removed in a future release. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 73945b39b8967..1ae3f36c152cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -22,7 +22,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS} +import org.apache.spark.sql.internal.SQLConf._ /** * An ordered collection of offsets, used to track the progress of processing data from one or more @@ -39,7 +39,9 @@ case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMet * cannot be serialized). */ def toStreamProgress(sources: Seq[BaseStreamingSource]): StreamProgress = { - assert(sources.size == offsets.size) + assert(sources.size == offsets.size, s"There are [${offsets.size}] sources in the " + + s"checkpoint offsets and now there are [${sources.size}] sources requested by the query. " + + s"Cannot continue.") new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) } } @@ -84,7 +86,22 @@ case class OffsetSeqMetadata( object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) - private val relevantSQLConfs = Seq(SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS) + private val relevantSQLConfs = Seq( + SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY) + + /** + * Default values of relevant configurations that are used for backward compatibility. + * As new configurations are added to the metadata, existing checkpoints may not have those + * confs. The values in this list ensures that the confs without recovered values are + * set to a default value that ensure the same behavior of the streaming query as it was before + * the restart. + * + * Note, that this is optional; set values here if you *have* to override existing session conf + * with a specific default value for ensuring same behavior of the query as before. + */ + private val relevantSQLConfDefaultValues = Map[String, String]( + STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME + ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) @@ -113,8 +130,22 @@ object OffsetSeqMetadata extends Logging { case None => // For backward compatibility, if a config was not recorded in the offset log, - // then log it, and let the existing conf value in SparkSession prevail. - logWarning (s"Conf '$confKey' was not found in the offset log, using existing value") + // then either inject a default value (if specified in `relevantSQLConfDefaultValues`) or + // let the existing conf value in SparkSession prevail. + relevantSQLConfDefaultValues.get(confKey) match { + + case Some(defaultValue) => + sessionConf.set(confKey, defaultValue) + logWarning(s"Conf '$confKey' was not found in the offset log, " + + s"using default value '$defaultValue'") + + case None => + val valueStr = sessionConf.getOption(confKey).map { v => + s" Using existing session conf value '$v'." + }.getOrElse { " No value set in session conf." } + logWarning(s"Conf '$confKey' was not found in the offset log. $valueStr") + + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index d1e5be9c12762..47f4b52e6e34c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -54,8 +56,6 @@ trait ProgressReporter extends Logging { protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution protected def newData: Map[BaseStreamingSource, LogicalPlan] - protected def availableOffsets: StreamProgress - protected def committedOffsets: StreamProgress protected def sources: Seq[BaseStreamingSource] protected def sink: BaseStreamingSink protected def offsetSeqMetadata: OffsetSeqMetadata @@ -66,8 +66,11 @@ trait ProgressReporter extends Logging { // Local timestamps and counters. private var currentTriggerStartTimestamp = -1L private var currentTriggerEndTimestamp = -1L + private var currentTriggerStartOffsets: Map[BaseStreamingSource, String] = _ + private var currentTriggerEndOffsets: Map[BaseStreamingSource, String] = _ // TODO: Restore this from the checkpoint when possible. private var lastTriggerStartTimestamp = -1L + private val currentDurationsMs = new mutable.HashMap[String, Long]() /** Flag that signals whether any error with input metrics have already been logged */ @@ -112,9 +115,20 @@ trait ProgressReporter extends Logging { lastTriggerStartTimestamp = currentTriggerStartTimestamp currentTriggerStartTimestamp = triggerClock.getTimeMillis() currentStatus = currentStatus.copy(isTriggerActive = true) + currentTriggerStartOffsets = null + currentTriggerEndOffsets = null currentDurationsMs.clear() } + /** + * Record the offsets range this trigger will process. Call this before updating + * `committedOffsets` in `StreamExecution` to make sure that the correct range is recorded. + */ + protected def recordTriggerOffsets(from: StreamProgress, to: StreamProgress): Unit = { + currentTriggerStartOffsets = from.mapValues(_.json) + currentTriggerEndOffsets = to.mapValues(_.json) + } + private def updateProgress(newProgress: StreamingQueryProgress): Unit = { progressBuffer.synchronized { progressBuffer += newProgress @@ -128,6 +142,7 @@ trait ProgressReporter extends Logging { /** Finalizes the query progress and adds it to list of recent status updates. */ protected def finishTrigger(hasNewData: Boolean): Unit = { + assert(currentTriggerStartOffsets != null && currentTriggerEndOffsets != null) currentTriggerEndTimestamp = triggerClock.getTimeMillis() val executionStats = extractExecutionStats(hasNewData) @@ -141,12 +156,12 @@ trait ProgressReporter extends Logging { } logDebug(s"Execution stats: $executionStats") - val sourceProgress = sources.map { source => + val sourceProgress = sources.distinct.map { source => val numRecords = executionStats.inputRows.getOrElse(source, 0L) new SourceProgress( description = source.toString, - startOffset = committedOffsets.get(source).map(_.json).orNull, - endOffset = availableOffsets.get(source).map(_.json).orNull, + startOffset = currentTriggerStartOffsets.get(source).orNull, + endOffset = currentTriggerEndOffsets.get(source).orNull, numInputRows = numRecords, inputRowsPerSecond = numRecords / inputTimeSec, processedRowsPerSecond = numRecords / processingTimeSec @@ -207,62 +222,126 @@ trait ProgressReporter extends Logging { return ExecutionStats(Map.empty, stateOperators, watermarkTimestamp) } - // We want to associate execution plan leaves to sources that generate them, so that we match - // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. - // Consider the translation from the streaming logical plan to the final executed plan. - // - // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan - // - // 1. We keep track of streaming sources associated with each leaf in the trigger's logical plan - // - Each logical plan leaf will be associated with a single streaming source. - // - There can be multiple logical plan leaves associated with a streaming source. - // - There can be leaves not associated with any streaming source, because they were - // generated from a batch source (e.g. stream-batch joins) - // - // 2. Assuming that the executed plan has same number of leaves in the same order as that of - // the trigger logical plan, we associate executed plan leaves with corresponding - // streaming sources. - // - // 3. For each source, we sum the metrics of the associated execution plan leaves. - // - val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => - logicalPlan.collectLeaves().map { leaf => leaf -> source } + val numInputRows = extractSourceToNumInputRows() + + val eventTimeStats = lastExecution.executedPlan.collect { + case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => + val stats = e.eventTimeStats.value + Map( + "max" -> stats.max, + "min" -> stats.min, + "avg" -> stats.avg.toLong).mapValues(formatTimestamp) + }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp + + ExecutionStats(numInputRows, stateOperators, eventTimeStats) + } + + /** Extract number of input sources for each streaming source in plan */ + private def extractSourceToNumInputRows(): Map[BaseStreamingSource, Long] = { + + import java.util.IdentityHashMap + import scala.collection.JavaConverters._ + + def sumRows(tuples: Seq[(BaseStreamingSource, Long)]): Map[BaseStreamingSource, Long] = { + tuples.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source } - val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming - val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() - val numInputRows: Map[BaseStreamingSource, Long] = + + val onlyDataSourceV2Sources = { + // Check whether the streaming query's logical plan has only V2 data sources + val allStreamingLeaves = + logicalPlan.collect { case s: StreamingExecutionRelation => s } + allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReader] } + } + + if (onlyDataSourceV2Sources) { + // DataSourceV2ScanExec is the execution plan leaf that is responsible for reading data + // from a V2 source and has a direct reference to the V2 source that generated it. Each + // DataSourceV2ScanExec records the number of rows it has read using SQLMetrics. However, + // just collecting all DataSourceV2ScanExec nodes and getting the metric is not correct as + // a DataSourceV2ScanExec instance may be referred to in the execution plan from two (or + // even multiple times) points and considering it twice will leads to double counting. We + // can't dedup them using their hashcode either because two different instances of + // DataSourceV2ScanExec can have the same hashcode but account for separate sets of + // records read, and deduping them to consider only one of them would be undercounting the + // records read. Therefore the right way to do this is to consider the unique instances of + // DataSourceV2ScanExec (using their identity hash codes) and get metrics from them. + // Hence we calculate in the following way. + // + // 1. Collect all the unique DataSourceV2ScanExec instances using IdentityHashMap. + // + // 2. Extract the source and the number of rows read from the DataSourceV2ScanExec instanes. + // + // 3. Multiple DataSourceV2ScanExec instance may refer to the same source (can happen with + // self-unions or self-joins). Add up the number of rows for each unique source. + val uniqueStreamingExecLeavesMap = + new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]() + + lastExecution.executedPlan.collectLeaves().foreach { + case s: DataSourceV2ScanExec if s.reader.isInstanceOf[BaseStreamingSource] => + uniqueStreamingExecLeavesMap.put(s, s) + case _ => + } + + val sourceToInputRowsTuples = + uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf => + val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) + val source = execLeaf.reader.asInstanceOf[BaseStreamingSource] + source -> numRows + }.toSeq + logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) + sumRows(sourceToInputRowsTuples) + } else { + + // Since V1 source do not generate execution plan leaves that directly link with source that + // generated it, we can only do a best-effort association between execution plan leaves to the + // sources. This is known to fail in a few cases, see SPARK-24050. + // + // We want to associate execution plan leaves to sources that generate them, so that we match + // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. + // Consider the translation from the streaming logical plan to the final executed plan. + // + // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan + // + // 1. We keep track of streaming sources associated with each leaf in trigger's logical plan + // - Each logical plan leaf will be associated with a single streaming source. + // - There can be multiple logical plan leaves associated with a streaming source. + // - There can be leaves not associated with any streaming source, because they were + // generated from a batch source (e.g. stream-batch joins) + // + // 2. Assuming that the executed plan has same number of leaves in the same order as that of + // the trigger logical plan, we associate executed plan leaves with corresponding + // streaming sources. + // + // 3. For each source, we sum the metrics of the associated execution plan leaves. + // + val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => + logicalPlan.collectLeaves().map { leaf => leaf -> source } + } + val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming + val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) { val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap { case (lp, ep) => logicalPlanLeafToSource.get(lp).map { source => ep -> source } } - val sourceToNumInputRows = execLeafToSource.map { case (execLeaf, source) => + val sourceToInputRowsTuples = execLeafToSource.map { case (execLeaf, source) => val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) source -> numRows } - sourceToNumInputRows.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source + sumRows(sourceToInputRowsTuples) } else { if (!metricWarningLogged) { def toString[T](seq: Seq[T]): String = s"(size = ${seq.size}), ${seq.mkString(", ")}" + logWarning( "Could not report metrics as number leaves in trigger logical plan did not match that" + - s" of the execution plan:\n" + - s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + - s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") + s" of the execution plan:\n" + + s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + + s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") metricWarningLogged = true } Map.empty } - - val eventTimeStats = lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => - val stats = e.eventTimeStats.value - Map( - "max" -> stats.max, - "min" -> stats.min, - "avg" -> stats.avg.toLong).mapValues(formatTimestamp) - }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp - - ExecutionStats(numInputRows, stateOperators, eventTimeStats) + } } /** Records the duration of running `body` for the next query progress update. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 3fc8c7887896a..290de873c5cfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -378,24 +378,6 @@ abstract class StreamExecution( } } - /** - * Signals to the thread executing micro-batches that it should stop running after the next - * batch. This method blocks until the thread stops running. - */ - override def stop(): Unit = { - // Set the state to TERMINATED so that the batching thread knows that it was interrupted - // intentionally - state.set(TERMINATED) - if (queryExecutionThread.isAlive) { - sparkSession.sparkContext.cancelJobGroup(runId.toString) - queryExecutionThread.interrupt() - queryExecutionThread.join() - // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak - sparkSession.sparkContext.cancelJobGroup(runId.toString) - } - logInfo(s"Query $prettyIdString was stopped") - } - /** * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala new file mode 100644 index 0000000000000..bf4af60c8cf03 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit.NANOSECONDS + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.streaming.state.StateStoreOps +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{LongType, NullType, StructField, StructType} +import org.apache.spark.util.CompletionIterator + +/** + * A physical operator for executing a streaming limit, which makes sure no more than streamLimit + * rows are returned. This operator is meant for streams in Append mode only. + */ +case class StreamingGlobalLimitExec( + streamLimit: Long, + child: SparkPlan, + stateInfo: Option[StatefulOperatorStateInfo] = None, + outputMode: Option[OutputMode] = None) + extends UnaryExecNode with StateStoreWriter { + + private val keySchema = StructType(Array(StructField("key", NullType))) + private val valueSchema = StructType(Array(StructField("value", LongType))) + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + + assert(outputMode.isDefined && outputMode.get == InternalOutputModes.Append, + "StreamingGlobalLimitExec is only valid for streams in Append output mode") + + child.execute().mapPartitionsWithStateStore( + getStateInfo, + keySchema, + valueSchema, + indexOrdinal = None, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + val key = UnsafeProjection.create(keySchema)(new GenericInternalRow(Array[Any](null))) + val numOutputRows = longMetric("numOutputRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val commitTimeMs = longMetric("commitTimeMs") + val updatesStartTimeNs = System.nanoTime + + val preBatchRowCount: Long = Option(store.get(key)).map(_.getLong(0)).getOrElse(0L) + var cumulativeRowCount = preBatchRowCount + + val result = iter.filter { r => + val x = cumulativeRowCount < streamLimit + if (x) { + cumulativeRowCount += 1 + } + x + } + + CompletionIterator[InternalRow, Iterator[InternalRow]](result, { + if (cumulativeRowCount > preBatchRowCount) { + numUpdatedStateRows += 1 + numOutputRows += cumulativeRowCount - preBatchRowCount + store.put(key, getValueRow(cumulativeRowCount)) + } + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) + }) + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = AllTuples :: Nil + + private def getValueRow(value: Long): UnsafeRow = { + UnsafeProjection.create(valueSchema)(new GenericInternalRow(Array[Any](value))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index f02d3a2c3733f..24195b5657e8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -66,6 +66,7 @@ case class StreamingExecutionRelation( output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { + override def otherCopyArgs: Seq[AnyRef] = session :: Nil override def isStreaming: Boolean = true override def toString: String = source.toString @@ -97,6 +98,7 @@ case class StreamingRelationV2( output: Seq[Attribute], v1Relation: Option[StreamingRelation])(session: SparkSession) extends LeafNode with MultiInstanceRelation { + override def otherCopyArgs: Seq[AnyRef] = session :: Nil override def isStreaming: Boolean = true override def toString: String = sourceName @@ -116,6 +118,7 @@ case class ContinuousExecutionRelation( output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { + override def otherCopyArgs: Seq[AnyRef] = session :: Nil override def isStreaming: Boolean = true override def toString: String = source.toString diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index fa7c8ee906ecd..50cf971e4ec3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -167,8 +167,8 @@ case class StreamingSymmetricHashJoinExec( val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: - ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil + HashClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: + HashClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil override def output: Seq[Attribute] = joinType match { case _: InnerLike => left.output ++ right.output @@ -187,6 +187,17 @@ case class StreamingSymmetricHashJoinExec( s"${getClass.getSimpleName} should not take $x as the JoinType") } + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + val watermarkUsedForStateCleanup = + stateWatermarkPredicates.left.nonEmpty || stateWatermarkPredicates.right.nonEmpty + + // Latest watermark value is more than that used in this previous executed plan + val watermarkHasChanged = + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + + watermarkUsedForStateCleanup && watermarkHasChanged + } + protected override def doExecute(): RDD[InternalRow] = { val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) @@ -319,8 +330,7 @@ case class StreamingSymmetricHashJoinExec( // outer join) if possible. In all cases, nothing needs to be outputted, hence the removal // needs to be done greedily by immediately consuming the returned iterator. val cleanupIter = joinType match { - case Inner => - leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() + case Inner => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() case LeftOuter => rightSideJoiner.removeOldState() case RightOuter => leftSideJoiner.removeOldState() case _ => throwBadJoinTypeException() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala new file mode 100644 index 0000000000000..7b30db44a2090 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.RuntimeConfig +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.internal.SQLConf + +/** + * Policy to define how to choose a new global watermark value if there are + * multiple watermark operators in a streaming query. + */ +sealed trait MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long +} + +object MultipleWatermarkPolicy { + val DEFAULT_POLICY_NAME = "min" + + def apply(policyName: String): MultipleWatermarkPolicy = { + policyName.toLowerCase match { + case DEFAULT_POLICY_NAME => MinWatermark + case "max" => MaxWatermark + case _ => + throw new IllegalArgumentException(s"Could not recognize watermark policy '$policyName'") + } + } +} + +/** + * Policy to choose the *min* of the operator watermark values as the global watermark value. + * Note that this is the safe (hence default) policy as the global watermark will advance + * only if all the individual operator watermarks have advanced. In other words, in a + * streaming query with multiple input streams and watermarks defined on all of them, + * the global watermark will advance as slowly as the slowest input. So if there is watermark + * based state cleanup or late-data dropping, then this policy is the most conservative one. + */ +case object MinWatermark extends MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long = { + assert(operatorWatermarks.nonEmpty) + operatorWatermarks.min + } +} + +/** + * Policy to choose the *min* of the operator watermark values as the global watermark value. So the + * global watermark will advance if any of the individual operator watermarks has advanced. + * In other words, in a streaming query with multiple input streams and watermarks defined on all + * of them, the global watermark will advance as fast as the fastest input. So if there is watermark + * based state cleanup or late-data dropping, then this policy is the most aggressive one and + * may lead to unexpected behavior if the data of the slow stream is delayed. + */ +case object MaxWatermark extends MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long = { + assert(operatorWatermarks.nonEmpty) + operatorWatermarks.max + } +} + +/** Tracks the watermark value of a streaming query based on a given `policy` */ +case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { + private val operatorToWatermarkMap = mutable.HashMap[Int, Long]() + private var globalWatermarkMs: Long = 0 + + def setWatermark(newWatermarkMs: Long): Unit = synchronized { + globalWatermarkMs = newWatermarkMs + } + + def updateWatermark(executedPlan: SparkPlan): Unit = synchronized { + val watermarkOperators = executedPlan.collect { + case e: EventTimeWatermarkExec => e + } + if (watermarkOperators.isEmpty) return + + watermarkOperators.zipWithIndex.foreach { + case (e, index) if e.eventTimeStats.value.count > 0 => + logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") + val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs + val prevWatermarkMs = operatorToWatermarkMap.get(index) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + operatorToWatermarkMap.put(index, newWatermarkMs) + } + + // Populate 0 if we haven't seen any data yet for this watermark node. + case (_, index) => + if (!operatorToWatermarkMap.isDefinedAt(index)) { + operatorToWatermarkMap.put(index, 0) + } + } + + // Update the global watermark to the minimum of all watermark nodes. + // This is the safest option, because only the global watermark is fault-tolerant. Making + // it the minimum of all individual watermarks guarantees it will never advance past where + // any individual watermark operator would be if it were in a plan by itself. + val chosenGlobalWatermark = policy.chooseGlobalWatermark(operatorToWatermarkMap.values.toSeq) + if (chosenGlobalWatermark > globalWatermarkMs) { + logInfo(s"Updating event-time watermark from $globalWatermarkMs to $chosenGlobalWatermark ms") + globalWatermarkMs = chosenGlobalWatermark + } else { + logDebug(s"Event time watermark didn't move: $chosenGlobalWatermark < $globalWatermarkMs") + } + } + + def currentWatermark: Long = synchronized { globalWatermarkMs } +} + +object WatermarkTracker { + def apply(conf: RuntimeConfig): WatermarkTracker = { + // If the session has been explicitly configured to use non-default policy then use it, + // otherwise use the default `min` policy as thats the safe thing to do. + // When recovering from a checkpoint location, it is expected that the `conf` will already + // be configured with the value present in the checkpoint. If there is no policy explicitly + // saved in the checkpoint (e.g., old checkpoints), then the default `min` policy is enforced + // through defaults specified in OffsetSeqMetadata.setSessionConf(). + val policyName = conf.get( + SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY, MultipleWatermarkPolicy.DEFAULT_POLICY_NAME) + new WatermarkTracker(MultipleWatermarkPolicy(policyName)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala new file mode 100644 index 0000000000000..5f60343bacfaa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.UUID + +import org.apache.spark.{HashPartitioner, SparkEnv} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD} + +/** + * Physical plan for coalescing a continuous processing plan. + * + * Currently, only coalesces to a single partition are supported. `numPartitions` must be 1. + */ +case class ContinuousCoalesceExec(numPartitions: Int, child: SparkPlan) extends SparkPlan { + override def output: Seq[Attribute] = child.output + + override def children: Seq[SparkPlan] = child :: Nil + + override def outputPartitioning: Partitioning = SinglePartition + + override def doExecute(): RDD[InternalRow] = { + assert(numPartitions == 1) + new ContinuousCoalesceRDD( + sparkContext, + numPartitions, + conf.continuousStreamingExecutorQueueSize, + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_INTERVAL_KEY).toLong, + child.execute()) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala new file mode 100644 index 0000000000000..ba85b355f974f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.UUID + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.continuous.shuffle._ +import org.apache.spark.util.ThreadUtils + +case class ContinuousCoalesceRDDPartition( + index: Int, + endpointName: String, + queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long) + extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. + lazy val (reader: ContinuousShuffleReader, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new RPCContinuousShuffleReader( + queueSize, numShuffleWriters, epochIntervalMs, env) + val endpoint = env.setupEndpoint(endpointName, receiver) + + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } + // This flag will be flipped on the executors to indicate that the threads processing + // partitions of the write-side RDD have been started. These will run indefinitely + // asynchronously as epochs of the coalesce RDD complete on the read side. + private[continuous] var writersInitialized: Boolean = false +} + +/** + * RDD for continuous coalescing. Asynchronously writes all partitions of `prev` into a local + * continuous shuffle, and then reads them in the task thread using `reader`. + */ +class ContinuousCoalesceRDD( + context: SparkContext, + numPartitions: Int, + readerQueueSize: Int, + epochIntervalMs: Long, + prev: RDD[InternalRow]) + extends RDD[InternalRow](context, Nil) { + + // When we support more than 1 target partition, we'll need to figure out how to pass in the + // required partitioner. + private val outputPartitioner = new HashPartitioner(1) + + private val readerEndpointNames = (0 until numPartitions).map { i => + s"ContinuousCoalesceRDD-part$i-${UUID.randomUUID()}" + } + + override def getPartitions: Array[Partition] = { + (0 until numPartitions).map { partIndex => + ContinuousCoalesceRDDPartition( + partIndex, + readerEndpointNames(partIndex), + readerQueueSize, + prev.getNumPartitions, + epochIntervalMs) + }.toArray + } + + private lazy val threadPool = ThreadUtils.newDaemonFixedThreadPool( + prev.getNumPartitions, + this.name) + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val part = split.asInstanceOf[ContinuousCoalesceRDDPartition] + + if (!part.writersInitialized) { + val rpcEnv = SparkEnv.get.rpcEnv + + // trigger lazy initialization + part.endpoint + val endpointRefs = readerEndpointNames.map { endpointName => + rpcEnv.setupEndpointRef(rpcEnv.address, endpointName) + } + + val runnables = prev.partitions.map { prevSplit => + new Runnable() { + override def run(): Unit = { + TaskContext.setTaskContext(context) + + val writer: ContinuousShuffleWriter = new RPCContinuousShuffleWriter( + prevSplit.index, outputPartitioner, endpointRefs.toArray) + + EpochTracker.initializeCurrentEpoch( + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) + while (!context.isInterrupted() && !context.isCompleted()) { + writer.write(prev.compute(prevSplit, context).asInstanceOf[Iterator[UnsafeRow]]) + // Note that current epoch is a non-inheritable thread local, so each writer thread + // can properly increment its own epoch without affecting the main task thread. + EpochTracker.incrementCurrentEpoch() + } + } + } + } + + context.addTaskCompletionListener { ctx => + threadPool.shutdownNow() + } + + part.writersInitialized = true + + runnables.foreach(threadPool.execute) + } + + part.reader.read() + } + + override def clearDependencies(): Unit = { + throw new IllegalStateException("Continuous RDDs cannot be checkpointed") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala new file mode 100644 index 0000000000000..73868d5967e90 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeInputPartitionReader} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, PartitionOffset} +import org.apache.spark.util.{NextIterator, ThreadUtils} + +class ContinuousDataSourceRDDPartition( + val index: Int, + val inputPartition: InputPartition[UnsafeRow]) + extends Partition with Serializable { + + // This is semantically a lazy val - it's initialized once the first time a call to + // ContinuousDataSourceRDD.compute() needs to access it, so it can be shared across + // all compute() calls for a partition. This ensures that one compute() picks up where the + // previous one ended. + // We don't make it actually a lazy val because it needs input which isn't available here. + // This will only be initialized on the executors. + private[continuous] var queueReader: ContinuousQueuedDataReader = _ +} + +/** + * The bottom-most RDD of a continuous processing read task. Wraps a [[ContinuousQueuedDataReader]] + * to read from the remote source, and polls that queue for incoming rows. + * + * Note that continuous processing calls compute() multiple times, and the same + * [[ContinuousQueuedDataReader]] instance will/must be shared between each call for the same split. + */ +class ContinuousDataSourceRDD( + sc: SparkContext, + dataQueueSize: Int, + epochPollIntervalMs: Long, + private val readerInputPartitions: Seq[InputPartition[UnsafeRow]]) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + readerInputPartitions.zipWithIndex.map { + case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition) + }.toArray + } + + /** + * Initialize the shared reader for this partition if needed, then read rows from it until + * it returns null to signal the end of the epoch. + */ + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + // If attempt number isn't 0, this is a task retry, which we don't support. + if (context.attemptNumber() != 0) { + throw new ContinuousTaskRetryException() + } + + val readerForPartition = { + val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] + if (partition.queueReader == null) { + partition.queueReader = + new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs) + } + + partition.queueReader + } + + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = { + readerForPartition.next() match { + case null => + finished = true + null + case row => row + } + } + + override def close(): Unit = {} + } + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + split.asInstanceOf[ContinuousDataSourceRDDPartition].inputPartition.preferredLocations() + } +} + +object ContinuousDataSourceRDD { + private[continuous] def getContinuousReader( + reader: InputPartitionReader[UnsafeRow]): ContinuousInputPartitionReader[_] = { + reader match { + case r: ContinuousInputPartitionReader[UnsafeRow] => r + case wrapped: RowToUnsafeInputPartitionReader => + wrapped.rowReader.asInstanceOf[ContinuousInputPartitionReader[Row]] + case _ => + throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala deleted file mode 100644 index 06754f01657d3..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.continuous - -import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} -import java.util.concurrent.atomic.AtomicBoolean - -import scala.collection.JavaConverters._ - -import org.apache.spark._ -import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} -import org.apache.spark.util.ThreadUtils - -class ContinuousDataSourceRDD( - sc: SparkContext, - sqlContext: SQLContext, - @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { - - private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize - private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs - - override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { - case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) - }.toArray - } - - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - // If attempt number isn't 0, this is a task retry, which we don't support. - if (context.attemptNumber() != 0) { - throw new ContinuousTaskRetryException() - } - - val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]] - .readerFactory.createDataReader() - - val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) - - // This queue contains two types of messages: - // * (null, null) representing an epoch boundary. - // * (row, off) containing a data row and its corresponding PartitionOffset. - val queue = new ArrayBlockingQueue[(UnsafeRow, PartitionOffset)](dataQueueSize) - - val epochPollFailed = new AtomicBoolean(false) - val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( - s"epoch-poll--$coordinatorId--${context.partitionId()}") - val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed) - epochPollExecutor.scheduleWithFixedDelay( - epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) - - // Important sequencing - we must get start offset before the data reader thread begins - val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset - - val dataReaderFailed = new AtomicBoolean(false) - val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed) - dataReaderThread.setDaemon(true) - dataReaderThread.start() - - context.addTaskCompletionListener(_ => { - dataReaderThread.interrupt() - epochPollExecutor.shutdown() - }) - - val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get) - new Iterator[UnsafeRow] { - private val POLL_TIMEOUT_MS = 1000 - - private var currentEntry: (UnsafeRow, PartitionOffset) = _ - private var currentOffset: PartitionOffset = startOffset - private var currentEpoch = - context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - override def hasNext(): Boolean = { - while (currentEntry == null) { - if (context.isInterrupted() || context.isCompleted()) { - currentEntry = (null, null) - } - if (dataReaderFailed.get()) { - throw new SparkException("data read failed", dataReaderThread.failureReason) - } - if (epochPollFailed.get()) { - throw new SparkException("epoch poll failed", epochPollRunnable.failureReason) - } - currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) - } - - currentEntry match { - // epoch boundary marker - case (null, null) => - epochEndpoint.send(ReportPartitionOffset( - context.partitionId(), - currentEpoch, - currentOffset)) - currentEpoch += 1 - currentEntry = null - false - // real row - case (_, offset) => - currentOffset = offset - true - } - } - - override def next(): UnsafeRow = { - if (currentEntry == null) throw new NoSuchElementException("No current row was set") - val r = currentEntry._1 - currentEntry = null - r - } - } - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readerFactory.preferredLocations() - } -} - -case class EpochPackedPartitionOffset(epoch: Long) extends PartitionOffset - -class EpochPollRunnable( - queue: BlockingQueue[(UnsafeRow, PartitionOffset)], - context: TaskContext, - failedFlag: AtomicBoolean) - extends Thread with Logging { - private[continuous] var failureReason: Throwable = _ - - private val epochEndpoint = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) - private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - override def run(): Unit = { - try { - val newEpoch = epochEndpoint.askSync[Long](GetCurrentEpoch) - for (i <- currentEpoch to newEpoch - 1) { - queue.put((null, null)) - logDebug(s"Sent marker to start epoch ${i + 1}") - } - currentEpoch = newEpoch - } catch { - case t: Throwable => - failureReason = t - failedFlag.set(true) - throw t - } - } -} - -class DataReaderThread( - reader: DataReader[UnsafeRow], - queue: BlockingQueue[(UnsafeRow, PartitionOffset)], - context: TaskContext, - failedFlag: AtomicBoolean) - extends Thread( - s"continuous-reader--${context.partitionId()}--" + - s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") { - private[continuous] var failureReason: Throwable = _ - - override def run(): Unit = { - TaskContext.setTaskContext(context) - val baseReader = ContinuousDataSourceRDD.getBaseReader(reader) - try { - while (!context.isInterrupted && !context.isCompleted()) { - if (!reader.next()) { - // Check again, since reader.next() might have blocked through an incoming interrupt. - if (!context.isInterrupted && !context.isCompleted()) { - throw new IllegalStateException( - "Continuous reader reported no elements! Reader should have blocked waiting.") - } else { - return - } - } - - queue.put((reader.get().copy(), baseReader.getOffset)) - } - } catch { - case _: InterruptedException if context.isInterrupted() => - // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. - - case t: Throwable => - failureReason = t - failedFlag.set(true) - // Don't rethrow the exception in this thread. It's not needed, and the default Spark - // exception handler will kill the executor. - } finally { - reader.close() - } - } -} - -object ContinuousDataSourceRDD { - private[continuous] def getBaseReader(reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { - reader match { - case r: ContinuousDataReader[UnsafeRow] => r - case wrapped: RowToUnsafeDataReader => - wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] - case _ => - throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 951d694355ec5..e991dbc81696d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -122,16 +122,7 @@ class ContinuousExecution( s"Batch $latestEpochId was committed without end epoch offsets!") } committedOffsets = nextOffsets.toStreamProgress(sources) - - // Get to an epoch ID that has definitely never been sent to a sink before. Since sink - // commit happens between offset log write and commit log write, this means an epoch ID - // which is not in the offset log. - val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { - throw new IllegalStateException( - s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + - s"an element.") - } - currentBatchId = latestOffsetEpoch + 1 + currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets @@ -199,7 +190,7 @@ class ContinuousExecution( triggerLogicalPlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) + val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan) val reader = withSink.collect { case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r @@ -225,6 +216,9 @@ class ContinuousExecution( currentEpochCoordinatorId = epochCoordinatorId sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId) + sparkSessionForQuery.sparkContext.setLocalProperty( + ContinuousExecution.EPOCH_INTERVAL_KEY, + trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString) // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = @@ -315,7 +309,10 @@ class ContinuousExecution( def commit(epoch: Long): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit") + synchronized { + // Record offsets before updating `committedOffsets` + recordTriggerOffsets(from = committedOffsets, to = availableOffsets) if (queryExecutionThread.isAlive) { commitLog.add(epoch) val offset = @@ -327,9 +324,14 @@ class ContinuousExecution( } } - if (minLogEntriesToMaintain < currentBatchId) { - offsetLog.purge(currentBatchId - minLogEntriesToMaintain) - commitLog.purge(currentBatchId - minLogEntriesToMaintain) + // Since currentBatchId increases independently in cp mode, the current committed epoch may + // be far behind currentBatchId. It is not safe to discard the metadata with thresholdBatchId + // computed based on currentBatchId. As minLogEntriesToMaintain is used to keep the minimum + // number of batches that must be retained and made recoverable, so we should keep the + // specified number of metadata that have been committed. + if (minLogEntriesToMaintain <= epoch) { + offsetLog.purge(epoch + 1 - minLogEntriesToMaintain) + commitLog.purge(epoch + 1 - minLogEntriesToMaintain) } awaitProgressLock.lock() @@ -365,9 +367,26 @@ class ContinuousExecution( } } } + + /** + * Stops the query execution thread to terminate the query. + */ + override def stop(): Unit = { + // Set the state to TERMINATED so that the batching thread knows that it was interrupted + // intentionally + state.set(TERMINATED) + if (queryExecutionThread.isAlive) { + // The query execution thread will clean itself up in the finally clause of runContinuous. + // We just need to interrupt the long running job. + queryExecutionThread.interrupt() + queryExecutionThread.join() + } + logInfo(s"Query $prettyIdString was stopped") + } } object ContinuousExecution { val START_EPOCH_KEY = "__continuous_start_epoch" val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id" + val EPOCH_INTERVAL_KEY = "__continuous_epoch_interval" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala new file mode 100644 index 0000000000000..8c74b8244d096 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.io.Closeable +import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} + +import scala.util.control.NonFatal + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset +import org.apache.spark.util.ThreadUtils + +/** + * A wrapper for a continuous processing data reader, including a reading queue and epoch markers. + * + * This will be instantiated once per partition - successive calls to compute() in the + * [[ContinuousDataSourceRDD]] will reuse the same reader. This is required to get continuity of + * offsets across epochs. Each compute() should call the next() method here until null is returned. + */ +class ContinuousQueuedDataReader( + partition: ContinuousDataSourceRDDPartition, + context: TaskContext, + dataQueueSize: Int, + epochPollIntervalMs: Long) extends Closeable { + private val reader = partition.inputPartition.createPartitionReader() + + // Important sequencing - we must get our starting point before the provider threads start running + private var currentOffset: PartitionOffset = + ContinuousDataSourceRDD.getContinuousReader(reader).getOffset + + /** + * The record types in the read buffer. + */ + sealed trait ContinuousRecord + case object EpochMarker extends ContinuousRecord + case class ContinuousRow(row: UnsafeRow, offset: PartitionOffset) extends ContinuousRecord + + private val queue = new ArrayBlockingQueue[ContinuousRecord](dataQueueSize) + + private val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) + private val epochCoordEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) + + private val epochMarkerExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + s"epoch-poll--$coordinatorId--${context.partitionId()}") + private val epochMarkerGenerator = new EpochMarkerGenerator + epochMarkerExecutor.scheduleWithFixedDelay( + epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) + + private val dataReaderThread = new DataReaderThread + dataReaderThread.setDaemon(true) + dataReaderThread.start() + + context.addTaskCompletionListener(_ => { + this.close() + }) + + private def shouldStop() = { + context.isInterrupted() || context.isCompleted() + } + + /** + * Return the next UnsafeRow to be read in the current epoch, or null if the epoch is done. + * + * After returning null, the [[ContinuousDataSourceRDD]] compute() for the following epoch + * will call next() again to start getting rows. + */ + def next(): UnsafeRow = { + val POLL_TIMEOUT_MS = 1000 + var currentEntry: ContinuousRecord = null + + while (currentEntry == null) { + if (shouldStop()) { + // Force the epoch to end here. The writer will notice the context is interrupted + // or completed and not start a new one. This makes it possible to achieve clean + // shutdown of the streaming query. + // TODO: The obvious generalization of this logic to multiple stages won't work. It's + // invalid to send an epoch marker from the bottom of a task if all its child tasks + // haven't sent one. + currentEntry = EpochMarker + } else { + if (dataReaderThread.failureReason != null) { + throw new SparkException("Data read failed", dataReaderThread.failureReason) + } + if (epochMarkerGenerator.failureReason != null) { + throw new SparkException( + "Epoch marker generation failed", + epochMarkerGenerator.failureReason) + } + currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) + } + } + + currentEntry match { + case EpochMarker => + epochCoordEndpoint.send(ReportPartitionOffset( + partition.index, EpochTracker.getCurrentEpoch.get, currentOffset)) + null + case ContinuousRow(row, offset) => + currentOffset = offset + row + } + } + + override def close(): Unit = { + dataReaderThread.interrupt() + epochMarkerExecutor.shutdown() + } + + /** + * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when + * a new row arrives to the [[InputPartitionReader]]. + */ + class DataReaderThread extends Thread( + s"continuous-reader--${context.partitionId()}--" + + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging { + @volatile private[continuous] var failureReason: Throwable = _ + + override def run(): Unit = { + TaskContext.setTaskContext(context) + val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader) + try { + while (!shouldStop()) { + if (!reader.next()) { + // Check again, since reader.next() might have blocked through an incoming interrupt. + if (!shouldStop()) { + throw new IllegalStateException( + "Continuous reader reported no elements! Reader should have blocked waiting.") + } else { + return + } + } + + queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset)) + } + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. + logInfo(s"shutting down interrupted data reader thread $getName") + + case NonFatal(t) => + failureReason = t + logWarning("data reader thread failed", t) + // If we throw from this thread, we may kill the executor. Let the parent thread handle + // it. + + case t: Throwable => + failureReason = t + throw t + } finally { + reader.close() + } + } + } + + /** + * The epoch marker component of [[ContinuousQueuedDataReader]]. Populates the queue with + * EpochMarker when a new epoch marker arrives. + */ + class EpochMarkerGenerator extends Runnable with Logging { + @volatile private[continuous] var failureReason: Throwable = _ + + private val epochCoordEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) + // Note that this is *not* the same as the currentEpoch in [[ContinuousWriteRDD]]! That + // field represents the epoch wrt the data being processed. The currentEpoch here is just a + // counter to ensure we send the appropriate number of markers if we fall behind the driver. + private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + override def run(): Unit = { + try { + val newEpoch = epochCoordEndpoint.askSync[Long](GetCurrentEpoch) + // It's possible to fall more than 1 epoch behind if a GetCurrentEpoch RPC ends up taking + // a while. We catch up by injecting enough epoch markers immediately to catch up. This will + // result in some epochs being empty for this partition, but that's fine. + for (i <- currentEpoch to newEpoch - 1) { + queue.put(EpochMarker) + logDebug(s"Sent marker to start epoch ${i + 1}") + } + currentEpoch = newEpoch + } catch { + case t: Throwable => + failureReason = t + throw t + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 2f0de2612c150..516a563bdcc7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeM import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( @@ -67,7 +67,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def getStartOffset(): Offset = offset - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + override def planInputPartitions(): java.util.List[InputPartition[Row]] = { val partitionStartMap = offset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => @@ -85,13 +85,13 @@ class RateStreamContinuousReader(options: DataSourceOptions) val start = partitionStartMap(i) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. - RateStreamContinuousDataReaderFactory( + RateStreamContinuousInputPartition( start.value, start.runTimeMs, i, numPartitions, perPartitionRate) - .asInstanceOf[DataReaderFactory[Row]] + .asInstanceOf[InputPartition[Row]] }.asJava } @@ -113,19 +113,19 @@ class RateStreamContinuousReader(options: DataSourceOptions) } -case class RateStreamContinuousDataReaderFactory( +case class RateStreamContinuousInputPartition( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousDataReaderFactory[Row] { + extends ContinuousInputPartition[Row] { - override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[Row] = { + override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[Row] = { val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] require(rateStreamOffset.partition == partitionIndex, s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") - new RateStreamContinuousDataReader( + new RateStreamContinuousInputPartitionReader( rateStreamOffset.currentValue, rateStreamOffset.currentTimeMs, partitionIndex, @@ -133,18 +133,18 @@ case class RateStreamContinuousDataReaderFactory( rowsPerSecond) } - override def createDataReader(): DataReader[Row] = - new RateStreamContinuousDataReader( + override def createPartitionReader(): InputPartitionReader[Row] = + new RateStreamContinuousInputPartitionReader( startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) } -class RateStreamContinuousDataReader( +class RateStreamContinuousInputPartitionReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousDataReader[Row] { + extends ContinuousInputPartitionReader[Row] { private var nextReadTime: Long = startTimeMs private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala new file mode 100644 index 0000000000000..76f3f5baa8d56 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.{Partition, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.util.Utils + +/** + * The RDD writing to a sink in continuous processing. + * + * Within each task, we repeatedly call prev.compute(). Each resulting iterator contains the data + * to be written for one epoch, which we commit and forward to the driver. + * + * We keep repeating prev.compute() and writing new epochs until the query is shut down. + */ +class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow]) + extends RDD[Unit](prev) { + + override val partitioner = prev.partitioner + + override def getPartitions: Array[Partition] = prev.partitions + + override def compute(split: Partition, context: TaskContext): Iterator[Unit] = { + val epochCoordinator = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + SparkEnv.get) + EpochTracker.initializeCurrentEpoch( + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) + + while (!context.isInterrupted() && !context.isCompleted()) { + var dataWriter: DataWriter[InternalRow] = null + // write the data and commit this writer. + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + try { + val dataIterator = prev.compute(split, context) + dataWriter = writeTask.createDataWriter( + context.partitionId(), + context.taskAttemptId(), + EpochTracker.getCurrentEpoch.get) + while (dataIterator.hasNext) { + dataWriter.write(dataIterator.next()) + } + logInfo(s"Writer for partition ${context.partitionId()} " + + s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.") + val msg = dataWriter.commit() + epochCoordinator.send( + CommitPartitionEpoch( + context.partitionId(), + EpochTracker.getCurrentEpoch.get, + msg) + ) + logInfo(s"Writer for partition ${context.partitionId()} " + + s"in epoch ${EpochTracker.getCurrentEpoch.get} committed.") + EpochTracker.incrementCurrentEpoch() + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt. Just finish the task. + } + })(catchBlock = { + // If there is an error, abort this writer. We enter this callback in the middle of + // rethrowing an exception, so compute() will stop executing at this point. + logError(s"Writer for partition ${context.partitionId()} is aborting.") + if (dataWriter != null) dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } + + Iterator() + } + + override def clearDependencies() { + super.clearDependencies() + prev = null + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index cc6808065c0cd..8877ebeb26735 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -137,30 +137,71 @@ private[continuous] class EpochCoordinator( private val partitionOffsets = mutable.Map[(Long, Int), PartitionOffset]() + private var lastCommittedEpoch = startEpoch - 1 + // Remembers epochs that have to wait for previous epochs to be committed first. + private val epochsWaitingToBeCommitted = mutable.HashSet.empty[Long] + private def resolveCommitsAtEpoch(epoch: Long) = { - val thisEpochCommits = - partitionCommits.collect { case ((e, _), msg) if e == epoch => msg } + val thisEpochCommits = findPartitionCommitsForEpoch(epoch) val nextEpochOffsets = partitionOffsets.collect { case ((e, _), o) if e == epoch => o } if (thisEpochCommits.size == numWriterPartitions && nextEpochOffsets.size == numReaderPartitions) { - logDebug(s"Epoch $epoch has received commits from all partitions. Committing globally.") - // Sequencing is important here. We must commit to the writer before recording the commit - // in the query, or we will end up dropping the commit if we restart in the middle. - writer.commit(epoch, thisEpochCommits.toArray) - query.commit(epoch) - - // Cleanup state from before this epoch, now that we know all partitions are forever past it. - for (k <- partitionCommits.keys.filter { case (e, _) => e < epoch }) { - partitionCommits.remove(k) - } - for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionOffsets.remove(k) + + // Check that last committed epoch is the previous one for sequencing of committed epochs. + // If not, add the epoch being currently processed to epochs waiting to be committed, + // otherwise commit it. + if (lastCommittedEpoch != epoch - 1) { + logDebug(s"Epoch $epoch has received commits from all partitions " + + s"and is waiting for epoch ${epoch - 1} to be committed first.") + epochsWaitingToBeCommitted.add(epoch) + } else { + commitEpoch(epoch, thisEpochCommits) + lastCommittedEpoch = epoch + + // Commit subsequent epochs that are waiting to be committed. + var nextEpoch = lastCommittedEpoch + 1 + while (epochsWaitingToBeCommitted.contains(nextEpoch)) { + val nextEpochCommits = findPartitionCommitsForEpoch(nextEpoch) + commitEpoch(nextEpoch, nextEpochCommits) + + epochsWaitingToBeCommitted.remove(nextEpoch) + lastCommittedEpoch = nextEpoch + nextEpoch += 1 + } + + // Cleanup state from before last committed epoch, + // now that we know all partitions are forever past it. + for (k <- partitionCommits.keys.filter { case (e, _) => e < lastCommittedEpoch }) { + partitionCommits.remove(k) + } + for (k <- partitionOffsets.keys.filter { case (e, _) => e < lastCommittedEpoch }) { + partitionOffsets.remove(k) + } } } } + /** + * Collect per-partition commits for an epoch. + */ + private def findPartitionCommitsForEpoch(epoch: Long): Iterable[WriterCommitMessage] = { + partitionCommits.collect { case ((e, _), msg) if e == epoch => msg } + } + + /** + * Commit epoch to the offset log. + */ + private def commitEpoch(epoch: Long, messages: Iterable[WriterCommitMessage]): Unit = { + logDebug(s"Epoch $epoch has received commits from all partitions " + + s"and is ready to be committed. Committing epoch $epoch.") + // Sequencing is important here. We must commit to the writer before recording the commit + // in the query, or we will end up dropping the commit if we restart in the middle. + writer.commit(epoch, messages.toArray) + query.commit(epoch) + } + override def receive: PartialFunction[Any, Unit] = { // If we just drop these messages, we won't do any writes to the query. The lame duck tasks // won't shed errors or anything. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala new file mode 100644 index 0000000000000..bc0ae428d4521 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.atomic.AtomicLong + +/** + * Tracks the current continuous processing epoch within a task. Call + * EpochTracker.getCurrentEpoch to get the current epoch. + */ +object EpochTracker { + // The current epoch. Note that this is a shared reference; ContinuousWriteRDD.compute() will + // update the underlying AtomicLong as it finishes epochs. Other code should only read the value. + private val currentEpoch: ThreadLocal[AtomicLong] = new ThreadLocal[AtomicLong] { + override def initialValue() = new AtomicLong(-1) + } + + /** + * Get the current epoch for the current task, or None if the task has no current epoch. + */ + def getCurrentEpoch: Option[Long] = { + currentEpoch.get().get() match { + case n if n < 0 => None + case e => Some(e) + } + } + + /** + * Increment the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]] + * between epochs. + */ + def incrementCurrentEpoch(): Unit = { + currentEpoch.get().incrementAndGet() + } + + /** + * Initialize the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]] + * at the beginning of a task. + */ + def initializeCurrentEpoch(startEpoch: Long): Unit = { + currentEpoch.get().set(startEpoch) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala new file mode 100644 index 0000000000000..943c731a70529 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter + +/** + * The logical plan for writing data in a continuous stream. + */ +case class WriteToContinuousDataSource( + writer: StreamWriter, query: LogicalPlan) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala new file mode 100644 index 0000000000000..e0af3a2f1b85d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import scala.util.control.NonFatal + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory} +import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.util.Utils + +/** + * The physical plan for writing data into a continuous processing [[StreamWriter]]. + */ +case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPlan) + extends SparkPlan with Logging { + override def children: Seq[SparkPlan] = Seq(query) + override def output: Seq[Attribute] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val writerFactory = writer match { + case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() + case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) + } + + val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) + + logInfo(s"Start processing data source writer: $writer. " + + s"The input RDD has ${rdd.partitions.length} partitions.") + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) + .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) + + try { + // Force the RDD to run so continuous processing starts; no data is actually being collected + // to the driver, as ContinuousWriteRDD outputs nothing. + rdd.collect() + } catch { + case _: InterruptedException => + // Interruption is how continuous queries are ended, so accept and ignore the exception. + case cause: Throwable => + cause match { + // Do not wrap interruption exceptions that will be handled by streaming specially. + case _ if StreamExecution.isInterruptionException(cause) => throw cause + // Only wrap non fatal exceptions. + case NonFatal(e) => throw new SparkException("Writing job aborted.", e) + case _ => throw cause + } + } + + sparkContext.emptyRDD + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala new file mode 100644 index 0000000000000..518223f3cd008 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.UUID + +import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcAddress +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.NextIterator + +case class ContinuousShuffleReadPartition( + index: Int, + endpointName: String, + queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long) + extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. + lazy val (reader: ContinuousShuffleReader, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new RPCContinuousShuffleReader( + queueSize, numShuffleWriters, epochIntervalMs, env) + val endpoint = env.setupEndpoint(endpointName, receiver) + + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } +} + +/** + * RDD at the map side of each continuous processing shuffle task. Upstream tasks send their + * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks + * poll from their receiver until an epoch marker is sent. + * + * @param sc the RDD context + * @param numPartitions the number of read partitions for this RDD + * @param queueSize the size of the row buffers to use + * @param numShuffleWriters the number of continuous shuffle writers feeding into this RDD + * @param epochIntervalMs the checkpoint interval of the streaming query + */ +class ContinuousShuffleReadRDD( + sc: SparkContext, + numPartitions: Int, + queueSize: Int = 1024, + numShuffleWriters: Int = 1, + epochIntervalMs: Long = 1000, + val endpointNames: Seq[String] = Seq(s"RPCContinuousShuffleReader-${UUID.randomUUID()}")) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + (0 until numPartitions).map { partIndex => + ContinuousShuffleReadPartition( + partIndex, endpointNames(partIndex), queueSize, numShuffleWriters, epochIntervalMs) + }.toArray + } + + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + split.asInstanceOf[ContinuousShuffleReadPartition].reader.read() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala new file mode 100644 index 0000000000000..42631c90ebc55 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Trait for reading from a continuous processing shuffle. + */ +trait ContinuousShuffleReader { + /** + * Returns an iterator over the incoming rows in an epoch. Implementations should block waiting + * for new rows to arrive, and end the iterator once they've received epoch markers from all + * shuffle writers. + */ + def read(): Iterator[UnsafeRow] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala new file mode 100644 index 0000000000000..47b1f78b24505 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Trait for writing to a continuous processing shuffle. + */ +trait ContinuousShuffleWriter { + def write(epoch: Iterator[UnsafeRow]): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala new file mode 100644 index 0000000000000..502ae0d4822e8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.NextIterator + +/** + * Messages for the RPCContinuousShuffleReader endpoint. Either an incoming row or an epoch marker. + * + * Each message comes tagged with writerId, identifying which writer the message is coming + * from. The receiver will only begin the next epoch once all writers have sent an epoch + * marker ending the current epoch. + */ +private[shuffle] sealed trait RPCContinuousShuffleMessage extends Serializable { + def writerId: Int +} +private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow) + extends RPCContinuousShuffleMessage +private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContinuousShuffleMessage + +/** + * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle + * writers will send rows here, with continuous shuffle readers polling for new rows as needed. + * + * TODO: Support multiple source tasks. We need to output a single epoch marker once all + * source tasks have sent one. + */ +private[continuous] class RPCContinuousShuffleReader( + queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long, + override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { + // Note that this queue will be drained from the main task thread and populated in the RPC + // response thread. + private val queues = Array.fill(numShuffleWriters) { + new ArrayBlockingQueue[RPCContinuousShuffleMessage](queueSize) + } + + // Exposed for testing to determine if the endpoint gets stopped on task end. + private[shuffle] val stopped = new AtomicBoolean(false) + + override def onStop(): Unit = { + stopped.set(true) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case r: RPCContinuousShuffleMessage => + // Note that this will block a thread the shared RPC handler pool! + // The TCP based shuffle handler (SPARK-24541) will avoid this problem. + queues(r.writerId).put(r) + context.reply(()) + } + + override def read(): Iterator[UnsafeRow] = { + new NextIterator[UnsafeRow] { + // An array of flags for whether each writer ID has gotten an epoch marker. + private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false) + + private val executor = Executors.newFixedThreadPool(numShuffleWriters) + private val completion = new ExecutorCompletionService[RPCContinuousShuffleMessage](executor) + + private def completionTask(writerId: Int) = new Callable[RPCContinuousShuffleMessage] { + override def call(): RPCContinuousShuffleMessage = queues(writerId).take() + } + + // Initialize by submitting tasks to read the first row from each writer. + (0 until numShuffleWriters).foreach(writerId => completion.submit(completionTask(writerId))) + + /** + * In each call to getNext(), we pull the next row available in the completion queue, and then + * submit another task to read the next row from the writer which returned it. + * + * When a writer sends an epoch marker, we note that it's finished and don't submit another + * task for it in this epoch. The iterator is over once all writers have sent an epoch marker. + */ + override def getNext(): UnsafeRow = { + var nextRow: UnsafeRow = null + while (!finished && nextRow == null) { + completion.poll(epochIntervalMs, TimeUnit.MILLISECONDS) match { + case null => + // Try again if the poll didn't wait long enough to get a real result. + // But we should be getting at least an epoch marker every checkpoint interval. + val writerIdsUncommitted = writerEpochMarkersReceived.zipWithIndex.collect { + case (flag, idx) if !flag => idx + } + logWarning( + s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " + + s"for writers ${writerIdsUncommitted.mkString(",")} to send epoch markers.") + + // The completion service guarantees this future will be available immediately. + case future => future.get() match { + case ReceiverRow(writerId, r) => + // Start reading the next element in the queue we just took from. + completion.submit(completionTask(writerId)) + nextRow = r + case ReceiverEpochMarker(writerId) => + // Don't read any more from this queue. If all the writers have sent epoch markers, + // the epoch is over; otherwise we need to loop again to poll from the remaining + // writers. + writerEpochMarkersReceived(writerId) = true + if (writerEpochMarkersReceived.forall(_ == true)) { + finished = true + } + } + } + } + + nextRow + } + + override def close(): Unit = { + executor.shutdownNow() + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala new file mode 100644 index 0000000000000..1c6f3ddb395e6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import scala.concurrent.Future +import scala.concurrent.duration.Duration + +import org.apache.spark.Partitioner +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.ThreadUtils + +/** + * A [[ContinuousShuffleWriter]] sending data to [[RPCContinuousShuffleReader]] instances. + * + * @param writerId The partition ID of this writer. + * @param outputPartitioner The partitioner on the reader side of the shuffle. + * @param endpoints The [[RPCContinuousShuffleReader]] endpoints to write to. Indexed by + * partition ID within outputPartitioner. + */ +class RPCContinuousShuffleWriter( + writerId: Int, + outputPartitioner: Partitioner, + endpoints: Array[RpcEndpointRef]) extends ContinuousShuffleWriter { + + if (outputPartitioner.numPartitions != 1) { + throw new IllegalArgumentException("multiple readers not yet supported") + } + + if (outputPartitioner.numPartitions != endpoints.length) { + throw new IllegalArgumentException(s"partitioner size ${outputPartitioner.numPartitions} did " + + s"not match endpoint count ${endpoints.length}") + } + + def write(epoch: Iterator[UnsafeRow]): Unit = { + while (epoch.hasNext) { + val row = epoch.next() + endpoints(outputPartitioner.getPartition(row)).askSync[Unit](ReceiverRow(writerId, row)) + } + + val futures = endpoints.map(_.ask[Unit](ReceiverEpochMarker(writerId))).toSeq + implicit val ec = ThreadUtils.sameThread + ThreadUtils.awaitResult(Future.sequence(futures), Duration.Inf) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 628923d367ce7..b137f98045c5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -24,23 +24,21 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} -import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.streaming.{OutputMode, Trigger} +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils - object MemoryStream { protected val currentBlockId = new AtomicInteger(0) protected val memoryStreamId = new AtomicInteger(0) @@ -141,7 +139,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) if (endOffset.offset == -1) null else endOffset } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 @@ -158,7 +156,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]] + new MemoryStreamInputPartition(block).asInstanceOf[InputPartition[UnsafeRow]] }.asJava } } @@ -203,10 +201,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } -class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) - extends DataReaderFactory[UnsafeRow] { - override def createDataReader(): DataReader[UnsafeRow] = { - new DataReader[UnsafeRow] { +class MemoryStreamInputPartition(records: Array[UnsafeRow]) + extends InputPartition[UnsafeRow] { + override def createPartitionReader(): InputPartitionReader[UnsafeRow] = { + new InputPartitionReader[UnsafeRow] { private var currentIndex = -1 override def next(): Boolean = { @@ -222,11 +220,20 @@ class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) } } +/** A common trait for MemorySinks with methods used for testing */ +trait MemorySinkBase extends BaseStreamingSink { + def allData: Seq[Row] + def latestBatchData: Seq[Row] + def dataSinceBatch(sinceBatchId: Long): Seq[Row] + def latestBatchId: Option[Long] +} + /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink + with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -236,7 +243,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { - batches.map(_.data).flatten + batches.flatMap(_.data) } def latestBatchId: Option[Long] = synchronized { @@ -245,6 +252,10 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) } + def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { + batches.filter(_.batchId > sinceBatchId).flatMap(_.data) + } + def toDebugString: String = synchronized { batches.map { case AddedData(batchId, data) => val dataStr = try data.mkString(" ") catch { @@ -294,7 +305,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode { def this(sink: MemorySink) = this(sink, sink.schema.toAttributes) - private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum + private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index c28919b8b729b..0bf90b8063326 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -23,19 +23,20 @@ import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ +import scala.collection.SortedMap import scala.collection.mutable.ListBuffer import org.json4s.NoTypeHints import org.json4s.jackson.Serialization -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.{Encoder, Row, SQLContext} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils @@ -44,13 +45,12 @@ import org.apache.spark.util.RpcUtils * * ContinuousMemoryStream maintains a list of records for each partition. addData() will * distribute records evenly-ish across partitions. * * RecordEndpoint is set up as an endpoint for executor-side - * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified - * offset within the list, or null if that offset doesn't yet have a record. + * ContinuousMemoryStreamInputPartitionReader instances to poll. It returns the record at + * the specified offset within the list, or null if that offset doesn't yet have a record. */ -class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) +class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { private implicit val formats = Serialization.formats(NoTypeHints) - private val NUM_PARTITIONS = 2 protected val logicalPlan = StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession) @@ -58,7 +58,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) // ContinuousReader implementation @GuardedBy("this") - private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A]) + private val records = Seq.fill(numPartitions)(new ListBuffer[A]) @GuardedBy("this") private var startOffset: ContinuousMemoryStreamOffset = _ @@ -69,17 +69,17 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def addData(data: TraversableOnce[A]): Offset = synchronized { // Distribute data evenly among partition lists. data.toSeq.zipWithIndex.map { - case (item, index) => records(index % NUM_PARTITIONS) += item + case (item, index) => records(index % numPartitions) += item } // The new target offset is the offset where all records in all partitions have been processed. - ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap) + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap) } override def setStartOffset(start: Optional[Offset]): Unit = synchronized { // Inferred initial offset is position 0 in each partition. startOffset = start.orElse { - ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap) + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) }.asInstanceOf[ContinuousMemoryStreamOffset] } @@ -99,7 +99,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) ) } - override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = { + override def planInputPartitions(): ju.List[InputPartition[Row]] = { synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = @@ -107,8 +107,8 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) startOffset.partitionNums.map { case (part, index) => - new ContinuousMemoryStreamDataReaderFactory( - endpointName, part, index): DataReaderFactory[Row] + new ContinuousMemoryStreamInputPartition( + endpointName, part, index): InputPartition[Row] }.toList.asJava } } @@ -152,28 +152,31 @@ object ContinuousMemoryStream { def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) + + def singlePartition[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1) } /** - * Data reader factory for continuous memory stream. + * An input partition for continuous memory stream. */ -class ContinuousMemoryStreamDataReaderFactory( +class ContinuousMemoryStreamInputPartition( driverEndpointName: String, partition: Int, - startOffset: Int) extends DataReaderFactory[Row] { - override def createDataReader: ContinuousMemoryStreamDataReader = - new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset) + startOffset: Int) extends InputPartition[Row] { + override def createPartitionReader: ContinuousMemoryStreamInputPartitionReader = + new ContinuousMemoryStreamInputPartitionReader(driverEndpointName, partition, startOffset) } /** - * Data reader for continuous memory stream. + * An input partition reader for continuous memory stream. * * Polls the driver endpoint for new records. */ -class ContinuousMemoryStreamDataReader( +class ContinuousMemoryStreamInputPartitionReader( driverEndpointName: String, partition: Int, - startOffset: Int) extends ContinuousDataReader[Row] { + startOffset: Int) extends ContinuousInputPartitionReader[Row] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, SparkEnv.get.conf, @@ -182,12 +185,19 @@ class ContinuousMemoryStreamDataReader( private var currentOffset = startOffset private var current: Option[Row] = None + // Defense-in-depth against failing to propagate the task context. Since it's not inheritable, + // we have to do a bit of error prone work to get it into every thread used by continuous + // processing. We hope that some unit test will end up instantiating a continuous memory stream + // in such cases. + if (TaskContext.get() == null) { + throw new IllegalStateException("Task context was not set!") + } + override def next(): Boolean = { - current = None + current = getRecord while (current.isEmpty) { Thread.sleep(10) - current = endpoint.askSync[Option[Row]]( - GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) + current = getRecord } currentOffset += 1 true @@ -199,6 +209,10 @@ class ContinuousMemoryStreamDataReader( override def getOffset: ContinuousMemoryStreamPartitionOffset = ContinuousMemoryStreamPartitionOffset(partition, currentOffset) + + private def getRecord: Option[Row] = + endpoint.askSync[Option[Row]]( + GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) } case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala new file mode 100644 index 0000000000000..03c567c58d46a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.api.python.PythonException +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.streaming.DataStreamWriter + +class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: ExpressionEncoder[T]) + extends Sink { + + override def addBatch(batchId: Long, data: DataFrame): Unit = { + val resolvedEncoder = encoder.resolveAndBind( + data.logicalPlan.output, + data.sparkSession.sessionState.analyzer) + val rdd = data.queryExecution.toRdd.map[T](resolvedEncoder.fromRow)(encoder.clsTag) + val ds = data.sparkSession.createDataset(rdd)(encoder) + batchWriter(ds, batchId) + } + + override def toString(): String = "ForeachBatchSink" +} + + +/** + * Interface that is meant to be extended by Python classes via Py4J. + * Py4J allows Python classes to implement Java interfaces so that the JVM can call back + * Python objects. In this case, this allows the user-defined Python `foreachBatch` function + * to be called from JVM when the query is active. + * */ +trait PythonForeachBatchFunction { + /** Call the Python implementation of this function */ + def call(batchDF: DataFrame, batchId: Long): Unit +} + +object PythonForeachBatchHelper { + def callForeachBatch(dsw: DataStreamWriter[Row], pythonFunc: PythonForeachBatchFunction): Unit = { + dsw.foreachBatch(pythonFunc.call _) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala index df5d69d57e36f..bc9b6d93ce7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.python.PythonForeachWriter import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -31,9 +33,14 @@ import org.apache.spark.sql.types.StructType * [[ForeachWriter]]. * * @param writer The [[ForeachWriter]] to process all data. + * @param converter An object to convert internal rows to target type T. Either it can be + * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. */ -case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { +case class ForeachWriterProvider[T]( + writer: ForeachWriter[T], + converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport { + override def createStreamWriter( queryId: String, schema: StructType, @@ -44,10 +51,16 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { - val encoder = encoderFor[T].resolveAndBind( - schema.toAttributes, - SparkSession.getActiveSession.get.sessionState.analyzer) - ForeachWriterFactory(writer, encoder) + val rowConverter: InternalRow => T = converter match { + case Left(enc) => + val boundEnc = enc.resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + boundEnc.fromRow + case Right(func) => + func + } + ForeachWriterFactory(writer, rowConverter) } override def toString: String = "ForeachSink" @@ -55,29 +68,44 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S } } -case class ForeachWriterFactory[T: Encoder]( +object ForeachWriterProvider { + def apply[T]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = { + writer match { + case pythonWriter: PythonForeachWriter => + new ForeachWriterProvider[UnsafeRow]( + pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) + case _ => + new ForeachWriterProvider[T](writer, Left(encoder)) + } + } +} + +case class ForeachWriterFactory[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]) + rowConverter: InternalRow => T) extends DataWriterFactory[InternalRow] { override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): ForeachDataWriter[T] = { - new ForeachDataWriter(writer, encoder, partitionId, epochId) + new ForeachDataWriter(writer, rowConverter, partitionId, epochId) } } /** * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * * @param writer The [[ForeachWriter]] to process all data. - * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]] + * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]] * @param partitionId * @param epochId * @tparam T The type expected by the writer. */ -class ForeachDataWriter[T : Encoder]( +class ForeachDataWriter[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T], + rowConverter: InternalRow => T, partitionId: Int, epochId: Long) extends DataWriter[InternalRow] { @@ -89,7 +117,7 @@ class ForeachDataWriter[T : Encoder]( if (!opened) return try { - writer.process(encoder.fromRow(record)) + writer.process(rowConverter(record)) } catch { case t: Throwable => writer.close(t) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index e07355aa37dba..b501d90c81f06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat case object PackedRowWriterFactory extends DataWriterFactory[Row] { override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[Row] = { new PackedRowDataWriter() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index 6cf8520fc544f..b393c48baee8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -134,7 +134,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: LongOffset(json.toLong) } - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + override def planInputPartitions(): java.util.List[InputPartition[Row]] = { val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") @@ -167,9 +167,9 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: } (0 until numPartitions).map { p => - new RateStreamMicroBatchDataReaderFactory( + new RateStreamMicroBatchInputPartition( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - : DataReaderFactory[Row] + : InputPartition[Row] }.toList.asJava } @@ -177,31 +177,32 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: override def stop(): Unit = {} - override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + + override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" } -class RateStreamMicroBatchDataReaderFactory( +class RateStreamMicroBatchInputPartition( partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReaderFactory[Row] { + relativeMsPerValue: Double) extends InputPartition[Row] { - override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( - partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) + override def createPartitionReader(): InputPartitionReader[Row] = + new RateStreamMicroBatchInputPartitionReader(partitionId, numPartitions, rangeStart, rangeEnd, + localStartTimeMs, relativeMsPerValue) } -class RateStreamMicroBatchDataReader( +class RateStreamMicroBatchInputPartitionReader( partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReader[Row] { - private var count = 0 + relativeMsPerValue: Double) extends InputPartitionReader[Row] { + private var count: Long = 0 override def next(): Boolean = { rangeStart + partitionId + numPartitions * count < rangeEnd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 5f58246083bb2..f2a35a90af24a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -27,8 +27,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} -import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -39,7 +40,7 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { +class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkBase with Logging { override def createStreamWriter( queryId: String, schema: StructType, @@ -67,6 +68,10 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { batches.lastOption.toSeq.flatten(_.data) } + def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { + batches.filter(_.batchId > sinceBatchId).flatMap(_.data) + } + def toDebugString: String = synchronized { batches.map { case AddedData(batchId, data) => val dataStr = try data.mkString(" ") catch { @@ -96,7 +101,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { case _ => throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") + s"Output mode $outputMode is not supported by MemorySinkV2") } } else { logDebug(s"Skipping already committed batch: $batchId") @@ -107,7 +112,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { batches.clear() } - override def toString(): String = "MemorySink" + override def toString(): String = "MemorySinkV2" } case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} @@ -149,7 +154,7 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] { override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[Row] = { new MemoryDataWriter(partitionId, outputMode) } @@ -175,10 +180,10 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode) /** - * Used to query the data that has been written into a [[MemorySink]]. + * Used to query the data that has been written into a [[MemorySinkV2]]. */ case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode { - private val sizePerRow = output.map(_.dataType.defaultSize).sum + private val sizePerRow = EstimationUtils.getSizePerRow(output) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 5aae46b463398..91e3b7179c34a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -22,6 +22,7 @@ import java.net.Socket import java.sql.Timestamp import java.text.SimpleDateFormat import java.util.{Calendar, List => JList, Locale, Optional} +import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ @@ -33,7 +34,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.LongOffset import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} @@ -76,7 +77,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR @GuardedBy("this") private var lastOffsetCommitted: LongOffset = LongOffset(-1L) - initialize() + private val initialized: AtomicBoolean = new AtomicBoolean(false) /** This method is only used for unit test */ private[sources] def getCurrentOffset(): LongOffset = synchronized { @@ -140,7 +141,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { assert(startOffset != null && endOffset != null, "start offset and end offset should already be set before create read tasks.") @@ -149,6 +150,10 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { + if (initialized.compareAndSet(false, true)) { + initialize() + } + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 batches.slice(sliceStart, sliceEnd) @@ -165,21 +170,22 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR (0 until numPartitions).map { i => val slice = slices(i) - new DataReaderFactory[Row] { - override def createDataReader(): DataReader[Row] = new DataReader[Row] { - private var currentIdx = -1 + new InputPartition[Row] { + override def createPartitionReader(): InputPartitionReader[Row] = + new InputPartitionReader[Row] { + private var currentIdx = -1 + + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } - override def next(): Boolean = { - currentIdx += 1 - currentIdx < slice.size - } + override def get(): Row = { + Row(slice(currentIdx)._1, slice(currentIdx)._2) + } - override def get(): Row = { - Row(slice(currentIdx)._1, slice(currentIdx)._2) + override def close(): Unit = {} } - - override def close(): Unit = {} - } } }.toList.asJava } @@ -214,7 +220,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def toString: String = s"TextSocket[host: $host, port: $port]" + override def toString: String = s"TextSocketV2[host: $host, port: $port]" } class TextSocketSourceProvider extends DataSourceV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index df722b953228b..118c82aa75e68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -18,12 +18,10 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ -import java.nio.channels.ClosedChannelException import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.util.Random import scala.util.control.NonFatal import com.google.common.io.ByteStreams @@ -280,38 +278,49 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit if (loadedCurrentVersionMap.isDefined) { return loadedCurrentVersionMap.get } - val snapshotCurrentVersionMap = readSnapshotFile(version) - if (snapshotCurrentVersionMap.isDefined) { - synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) } - return snapshotCurrentVersionMap.get - } - // Find the most recent map before this version that we can. - // [SPARK-22305] This must be done iteratively to avoid stack overflow. - var lastAvailableVersion = version - var lastAvailableMap: Option[MapType] = None - while (lastAvailableMap.isEmpty) { - lastAvailableVersion -= 1 + logWarning(s"The state for version $version doesn't exist in loadedMaps. " + + "Reading snapshot file and delta files if needed..." + + "Note that this is normal for the first batch of starting query.") - if (lastAvailableVersion <= 0) { - // Use an empty map for versions 0 or less. - lastAvailableMap = Some(new MapType) - } else { - lastAvailableMap = - synchronized { loadedMaps.get(lastAvailableVersion) } - .orElse(readSnapshotFile(lastAvailableVersion)) + val (result, elapsedMs) = Utils.timeTakenMs { + val snapshotCurrentVersionMap = readSnapshotFile(version) + if (snapshotCurrentVersionMap.isDefined) { + synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) } + return snapshotCurrentVersionMap.get + } + + // Find the most recent map before this version that we can. + // [SPARK-22305] This must be done iteratively to avoid stack overflow. + var lastAvailableVersion = version + var lastAvailableMap: Option[MapType] = None + while (lastAvailableMap.isEmpty) { + lastAvailableVersion -= 1 + + if (lastAvailableVersion <= 0) { + // Use an empty map for versions 0 or less. + lastAvailableMap = Some(new MapType) + } else { + lastAvailableMap = + synchronized { loadedMaps.get(lastAvailableVersion) } + .orElse(readSnapshotFile(lastAvailableVersion)) + } + } + + // Load all the deltas from the version after the last available one up to the target version. + // The last available version is the one with a full snapshot, so it doesn't need deltas. + val resultMap = new MapType(lastAvailableMap.get) + for (deltaVersion <- lastAvailableVersion + 1 to version) { + updateFromDeltaFile(deltaVersion, resultMap) } - } - // Load all the deltas from the version after the last available one up to the target version. - // The last available version is the one with a full snapshot, so it doesn't need deltas. - val resultMap = new MapType(lastAvailableMap.get) - for (deltaVersion <- lastAvailableVersion + 1 to version) { - updateFromDeltaFile(deltaVersion, resultMap) + synchronized { loadedMaps.put(version, resultMap) } + resultMap } - synchronized { loadedMaps.put(version, resultMap) } - resultMap + logDebug(s"Loading state for $version takes $elapsedMs ms.") + + result } private def writeUpdateToDeltaFile( @@ -490,7 +499,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Perform a snapshot of the store to allow delta files to be consolidated */ private def doSnapshot(): Unit = { try { - val files = fetchFiles() + val (files, e1) = Utils.timeTakenMs(fetchFiles()) + logDebug(s"fetchFiles() took $e1 ms.") + if (files.nonEmpty) { val lastVersion = files.last.version val deltaFilesForLastVersion = @@ -498,7 +509,8 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit synchronized { loadedMaps.get(lastVersion) } match { case Some(map) => if (deltaFilesForLastVersion.size > storeConf.minDeltasForSnapshot) { - writeSnapshotFile(lastVersion, map) + val (_, e2) = Utils.timeTakenMs(writeSnapshotFile(lastVersion, map)) + logDebug(s"writeSnapshotFile() took $e2 ms.") } case None => // The last map is not loaded, probably some other instance is in charge @@ -517,7 +529,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit */ private[state] def cleanup(): Unit = { try { - val files = fetchFiles() + val (files, e1) = Utils.timeTakenMs(fetchFiles()) + logDebug(s"fetchFiles() took $e1 ms.") + if (files.nonEmpty) { val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain if (earliestVersionToRetain > 0) { @@ -527,9 +541,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit mapsToRemove.foreach(loadedMaps.remove) } val filesToDelete = files.filter(_.version < earliestFileToRetain.version) - filesToDelete.foreach { f => - fm.delete(f.path) + val (_, e2) = Utils.timeTakenMs { + filesToDelete.foreach { f => + fm.delete(f.path) + } } + logDebug(s"deleting files took $e2 ms.") logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this: " + filesToDelete.mkString(", ")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 01d8e75980993..3f11b8f79943c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.streaming.continuous.EpochTracker import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -71,8 +72,15 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( StateStoreId(checkpointLocation, operatorId, partition.index), queryRunId) + // If we're in continuous processing mode, we should get the store version for the current + // epoch rather than the one at planning time. + val currentVersion = EpochTracker.getCurrentEpoch match { + case None => storeVersion + case Some(value) => value + } + store = StateStore.get( - storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion, + storeProviderId, keySchema, valueSchema, indexOrdinal, currentVersion, storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index b9b07a2e688f9..6759fb42b4052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress} import org.apache.spark.sql.types._ -import org.apache.spark.util.{CompletionIterator, NextIterator} +import org.apache.spark.util.{CompletionIterator, NextIterator, Utils} /** Used to identify the state store for a given operator. */ @@ -97,12 +97,7 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => } /** Records the duration of running `body` for the next query progress update. */ - protected def timeTakenMs(body: => Unit): Long = { - val startTime = System.nanoTime() - val result = body - val endTime = System.nanoTime() - math.max(NANOSECONDS.toMillis(endTime - startTime), 0) - } + protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2 /** * Set the SQL metrics related to the state store. @@ -126,6 +121,12 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => name -> SQLMetrics.createTimingMetric(sparkContext, desc) }.toMap } + + /** + * Should the MicroBatchExecution run another batch based on this stateful operator and the + * current updated metadata. + */ + def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = false } /** An operator that supports watermark. */ @@ -340,37 +341,35 @@ case class StateStoreSaveExec( // Update and output modified rows from the StateStore. case Some(Update) => - val updatesStartTimeNs = System.nanoTime - - new Iterator[InternalRow] { - + new NextIterator[InternalRow] { // Filter late date using watermark if specified private[this] val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) case None => iter } + private val updatesStartTimeNs = System.nanoTime - override def hasNext: Boolean = { - if (!baseIterator.hasNext) { - allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) - - // Remove old aggregates if watermark specified - allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } - commitTimeMs += timeTakenMs { store.commit() } - setStoreMetrics(store) - false + override protected def getNext(): InternalRow = { + if (baseIterator.hasNext) { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key, row) + numOutputRows += 1 + numUpdatedStateRows += 1 + row } else { - true + finished = true + null } } - override def next(): InternalRow = { - val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) - numOutputRows += 1 - numUpdatedStateRows += 1 - row + override protected def close(): Unit = { + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + + // Remove old aggregates if watermark specified + allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) } } @@ -390,6 +389,12 @@ case class StateStoreSaveExec( ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + (outputMode.contains(Append) || outputMode.contains(Update)) && + eventTimeWatermark.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermark.get + } } /** Physical operator for executing streaming Deduplicate. */ @@ -456,6 +461,10 @@ case class StreamingDeduplicateExec( override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + } } object StreamingDeduplicateExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 582528777f90e..a7a24ac3641b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -58,21 +58,21 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L _content ++= new RunningExecutionTable( parent, s"Running Queries (${running.size})", currentTime, - running.sortBy(_.submissionTime).reverse).toNodeSeq + running.sortBy(_.submissionTime).reverse).toNodeSeq(request) } if (completed.nonEmpty) { _content ++= new CompletedExecutionTable( parent, s"Completed Queries (${completed.size})", currentTime, - completed.sortBy(_.submissionTime).reverse).toNodeSeq + completed.sortBy(_.submissionTime).reverse).toNodeSeq(request) } if (failed.nonEmpty) { _content ++= new FailedExecutionTable( parent, s"Failed Queries (${failed.size})", currentTime, - failed.sortBy(_.submissionTime).reverse).toNodeSeq + failed.sortBy(_.submissionTime).reverse).toNodeSeq(request) } _content } @@ -111,7 +111,7 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L } - UIUtils.headerSparkPage("SQL", summary ++ content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "SQL", summary ++ content, parent, Some(5000)) } } @@ -133,7 +133,10 @@ private[ui] abstract class ExecutionTable( protected def header: Seq[String] - protected def row(currentTime: Long, executionUIData: SQLExecutionUIData): Seq[Node] = { + protected def row( + request: HttpServletRequest, + currentTime: Long, + executionUIData: SQLExecutionUIData): Seq[Node] = { val submissionTime = executionUIData.submissionTime val duration = executionUIData.completionTime.map(_.getTime()).getOrElse(currentTime) - submissionTime @@ -141,7 +144,7 @@ private[ui] abstract class ExecutionTable( def jobLinks(status: JobExecutionStatus): Seq[Node] = { executionUIData.jobs.flatMap { case (jobId, jobStatus) => if (jobStatus == status) { - [{jobId.toString}] + [{jobId.toString}] } else { None } @@ -153,7 +156,7 @@ private[ui] abstract class ExecutionTable( {executionUIData.executionId.toString} - {descriptionCell(executionUIData)} + {descriptionCell(request, executionUIData)} {UIUtils.formatDate(submissionTime)} @@ -179,7 +182,9 @@ private[ui] abstract class ExecutionTable( } - private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { + private def descriptionCell( + request: HttpServletRequest, + execution: SQLExecutionUIData): Seq[Node] = { val details = if (execution.details != null && execution.details.nonEmpty) { +details @@ -192,27 +197,28 @@ private[ui] abstract class ExecutionTable( } val desc = if (execution.description != null && execution.description.nonEmpty) { - {execution.description} + {execution.description} } else { - {execution.executionId} + {execution.executionId} }
      {desc} {details}
      } - def toNodeSeq: Seq[Node] = { + def toNodeSeq(request: HttpServletRequest): Seq[Node] = {

      {tableName}

      {UIUtils.listingTable[SQLExecutionUIData]( - header, row(currentTime, _), executionUIDatas, id = Some(tableId))} + header, row(request, currentTime, _), executionUIDatas, id = Some(tableId))}
      } - private def jobURL(jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) + private def jobURL(request: HttpServletRequest, jobId: Long): String = + "%s/jobs/job/?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) - private def executionURL(executionID: Long): String = - s"${UIUtils.prependBaseUri(parent.basePath)}/${parent.prefix}/execution?id=$executionID" + private def executionURL(request: HttpServletRequest, executionID: Long): String = + s"${UIUtils.prependBaseUri( + request, parent.basePath)}/${parent.prefix}/execution/?id=$executionID" } private[ui] class RunningExecutionTable( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index e0554f0c4d337..877176b030f8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -49,7 +49,7 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
    • {label} {jobs.toSeq.sorted.map { jobId => - {jobId.toString}  + {jobId.toString}  }}
    • } else { @@ -77,27 +77,31 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging val graph = sqlStore.planGraph(executionId) summary ++ - planVisualization(metrics, graph) ++ + planVisualization(request, metrics, graph) ++ physicalPlanDescription(executionUIData.physicalPlanDescription) }.getOrElse {
      No information to display for query {executionId}
      } - UIUtils.headerSparkPage(s"Details for Query $executionId", content, parent, Some(5000)) + UIUtils.headerSparkPage( + request, s"Details for Query $executionId", content, parent, Some(5000)) } - private def planVisualizationResources: Seq[Node] = { + private def planVisualizationResources(request: HttpServletRequest): Seq[Node] = { // scalastyle:off - - - - - + + + + + // scalastyle:on } - private def planVisualization(metrics: Map[Long, String], graph: SparkPlanGraph): Seq[Node] = { + private def planVisualization( + request: HttpServletRequest, + metrics: Map[Long, String], + graph: SparkPlanGraph): Seq[Node] = { val metadata = graph.allNodes.flatMap { node => val nodeId = s"plan-meta-data-${node.id}"
      {node.desc}
      @@ -112,13 +116,13 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
      {graph.allNodes.size.toString}
      {metadata} - {planVisualizationResources} + {planVisualizationResources(request)} } - private def jobURL(jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) + private def jobURL(request: HttpServletRequest, jobId: Long): String = + "%s/jobs/job/?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) private def physicalPlanDescription(physicalPlanDescription: String): Seq[Node] = {
      diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 2b6bb48467eb3..d254af400a7cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -289,7 +289,7 @@ class SQLAppStatusListener( private def onDriverAccumUpdates(event: SparkListenerDriverAccumUpdates): Unit = { val SparkListenerDriverAccumUpdates(executionId, accumUpdates) = event Option(liveExecutions.get(executionId)).foreach { exec => - exec.driverAccumUpdates = accumUpdates.toMap + exec.driverAccumUpdates = exec.driverAccumUpdates ++ accumUpdates update(exec) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 800a2ea3f3996..626f39d9e95cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -112,9 +112,11 @@ case class WindowExec( * * @param frame to evaluate. This can either be a Row or Range frame. * @param bound with respect to the row. + * @param timeZone the session local timezone for time related calculations. * @return a bound ordering object. */ - private[this] def createBoundOrdering(frame: FrameType, bound: Expression): BoundOrdering = { + private[this] def createBoundOrdering( + frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = { (frame, bound) match { case (RowFrame, CurrentRow) => RowBoundOrdering(0) @@ -144,7 +146,7 @@ case class WindowExec( val boundExpr = (expr.dataType, boundOffset.dataType) match { case (DateType, IntegerType) => DateAdd(expr, boundOffset) case (TimestampType, CalendarIntervalType) => - TimeAdd(expr, boundOffset, Some(conf.sessionLocalTimeZone)) + TimeAdd(expr, boundOffset, Some(timeZone)) case (a, b) if a== b => Add(expr, boundOffset) } val bound = newMutableProjection(boundExpr :: Nil, child.output) @@ -197,6 +199,7 @@ case class WindowExec( // Map the groups to a (unbound) expression and frame factory pair. var numExpressions = 0 + val timeZone = conf.sessionLocalTimeZone framedFunctions.toSeq.map { case (key, (expressions, functionSeq)) => val ordinal = numExpressions @@ -237,7 +240,7 @@ case class WindowExec( new UnboundedPrecedingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, upper)) + createBoundOrdering(frameType, upper, timeZone)) } // Shrinking Frame. @@ -246,7 +249,7 @@ case class WindowExec( new UnboundedFollowingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, lower)) + createBoundOrdering(frameType, lower, timeZone)) } // Moving Frame. @@ -255,8 +258,8 @@ case class WindowExec( new SlidingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, lower), - createBoundOrdering(frameType, upper)) + createBoundOrdering(frameType, lower, timeZone), + createBoundOrdering(frameType, upper, timeZone)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bea8c0e445002..b98ab11e56feb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -283,6 +283,9 @@ object functions { /** * Aggregate function: returns a list of objects with duplicates. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -291,6 +294,9 @@ object functions { /** * Aggregate function: returns a list of objects with duplicates. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -299,6 +305,9 @@ object functions { /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -307,6 +316,9 @@ object functions { /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -422,6 +434,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -435,6 +450,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -448,6 +466,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -459,6 +480,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -535,6 +559,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -548,6 +575,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -561,6 +591,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -572,6 +605,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -775,6 +811,7 @@ object functions { */ def var_pop(columnName: String): Column = var_pop(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -994,14 +1031,6 @@ object functions { // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// - /** - * Computes the absolute value. - * - * @group normal_funcs - * @since 1.3.0 - */ - def abs(e: Column): Column = withExpr { Abs(e.expr) } - /** * Creates a new array column. The input columns must all have the same data type. * @@ -1033,6 +1062,17 @@ object functions { @scala.annotation.varargs def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) } + /** + * Creates a new map column. The array in the first column is used for keys. The array in the + * second column is used for values. All elements in the array for key should not be null. + * + * @group normal_funcs + * @since 2.4 + */ + def map_from_arrays(keys: Column, values: Column): Column = withExpr { + MapFromArrays(keys.expr, values.expr) + } + /** * Marks a DataFrame as small enough for use in broadcast joins. * @@ -1172,7 +1212,7 @@ object functions { * Generate a random column with independent and identically distributed (i.i.d.) samples * from U[0.0, 1.0]. * - * @note This is indeterministic when data partitions are not fixed. + * @note The function is non-deterministic in general case. * * @group normal_funcs * @since 1.4.0 @@ -1183,6 +1223,8 @@ object functions { * Generate a random column with independent and identically distributed (i.i.d.) samples * from U[0.0, 1.0]. * + * @note The function is non-deterministic in general case. + * * @group normal_funcs * @since 1.4.0 */ @@ -1192,7 +1234,7 @@ object functions { * Generate a column with independent and identically distributed (i.i.d.) samples from * the standard normal distribution. * - * @note This is indeterministic when data partitions are not fixed. + * @note The function is non-deterministic in general case. * * @group normal_funcs * @since 1.4.0 @@ -1203,6 +1245,8 @@ object functions { * Generate a column with independent and identically distributed (i.i.d.) samples from * the standard normal distribution. * + * @note The function is non-deterministic in general case. + * * @group normal_funcs * @since 1.4.0 */ @@ -1211,7 +1255,7 @@ object functions { /** * Partition ID. * - * @note This is indeterministic because it depends on data partitioning and task scheduling. + * @note This is non-deterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs * @since 1.6.0 @@ -1284,7 +1328,7 @@ object functions { } /** - * Computes bitwise NOT. + * Computes bitwise NOT (~) of a number. * * @group normal_funcs * @since 1.4.0 @@ -1312,6 +1356,14 @@ object functions { // Math Functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Computes the absolute value of a numeric value. + * + * @group math_funcs + * @since 1.3.0 + */ + def abs(e: Column): Column = withExpr { Abs(e.expr) } + /** * @return inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos` * @@ -2691,11 +2743,27 @@ object functions { /** * Returns number of months between dates `date1` and `date2`. + * If `date1` is later than `date2`, then the result is positive. + * If `date1` and `date2` are on the same day of month, or both are the last day of month, + * time of day will be ignored. + * + * Otherwise, the difference is calculated based on 31 days per month, and rounded to + * 8 digits. * @group datetime_funcs * @since 1.5.0 */ def months_between(date1: Column, date2: Column): Column = withExpr { - MonthsBetween(date1.expr, date2.expr) + new MonthsBetween(date1.expr, date2.expr) + } + + /** + * Returns number of months between dates `date1` and `date2`. If `roundOff` is set to true, the + * result is rounded off to 8 digits; it is not rounded otherwise. + * @group datetime_funcs + * @since 2.4.0 + */ + def months_between(date1: Column, date2: Column, roundOff: Boolean): Column = withExpr { + MonthsBetween(date1.expr, date2.expr, lit(roundOff).expr) } /** @@ -2866,6 +2934,17 @@ object functions { FromUTCTimestamp(ts.expr, Literal(tz)) } + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders + * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield + * '2017-07-14 03:40:00.0'. + * @group datetime_funcs + * @since 2.4.0 + */ + def from_utc_timestamp(ts: Column, tz: Column): Column = withExpr { + FromUTCTimestamp(ts.expr, tz.expr) + } + /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield @@ -2877,6 +2956,17 @@ object functions { ToUTCTimestamp(ts.expr, Literal(tz)) } + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time + * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield + * '2017-07-14 01:40:00.0'. + * @group datetime_funcs + * @since 2.4.0 + */ + def to_utc_timestamp(ts: Column, tz: Column): Column = withExpr { + ToUTCTimestamp(ts.expr, tz.expr) + } + /** * Bucketize rows into one or more time windows given a timestamp specifying column. Window * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window @@ -3025,7 +3115,47 @@ object functions { * @since 1.5.0 */ def array_contains(column: Column, value: Any): Column = withExpr { - ArrayContains(column.expr, Literal(value)) + ArrayContains(column.expr, lit(value).expr) + } + + /** + * Returns `true` if `a1` and `a2` have at least one non-null element in common. If not and both + * the arrays are non-empty and any of them contains a `null`, it returns `null`. It returns + * `false` otherwise. + * @group collection_funcs + * @since 2.4.0 + */ + def arrays_overlap(a1: Column, a2: Column): Column = withExpr { + ArraysOverlap(a1.expr, a2.expr) + } + + /** + * Returns an array containing all the elements in `x` from index `start` (or starting from the + * end if `start` is negative) with the specified `length`. + * @group collection_funcs + * @since 2.4.0 + */ + def slice(x: Column, start: Int, length: Int): Column = withExpr { + Slice(x.expr, Literal(start), Literal(length)) + } + + /** + * Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + * `nullReplacement`. + * @group collection_funcs + * @since 2.4.0 + */ + def array_join(column: Column, delimiter: String, nullReplacement: String): Column = withExpr { + ArrayJoin(column.expr, Literal(delimiter), Some(Literal(nullReplacement))) + } + + /** + * Concatenates the elements of `column` using the `delimiter`. + * @group collection_funcs + * @since 2.4.0 + */ + def array_join(column: Column, delimiter: String): Column = withExpr { + ArrayJoin(column.expr, Literal(delimiter), None) } /** @@ -3049,7 +3179,7 @@ object functions { * @since 2.4.0 */ def array_position(column: Column, value: Any): Column = withExpr { - ArrayPosition(column.expr, Literal(value)) + ArrayPosition(column.expr, lit(value).expr) } /** @@ -3060,7 +3190,43 @@ object functions { * @since 2.4.0 */ def element_at(column: Column, value: Any): Column = withExpr { - ElementAt(column.expr, Literal(value)) + ElementAt(column.expr, lit(value).expr) + } + + /** + * Sorts the input array in ascending order. The elements of the input array must be orderable. + * Null elements will be placed at the end of the returned array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) } + + /** + * Remove all elements that equal to element from the given array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_remove(column: Column, element: Any): Column = withExpr { + ArrayRemove(column.expr, lit(element).expr) + } + + /** + * Removes duplicate values from the array. + * @group collection_funcs + * @since 2.4.0 + */ + def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + + /** + * Returns an array of the elements in the union of the given two arrays, without duplicates. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_union(col1: Column, col2: Column): Column = withExpr { + ArrayUnion(col1.expr, col2.expr) } /** @@ -3136,9 +3302,9 @@ object functions { from_json(e, schema.asInstanceOf[DataType], options) /** - * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3168,9 +3334,9 @@ object functions { from_json(e, schema, options.asScala.toMap) /** - * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3197,8 +3363,9 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s - * with the specified schema. Returns `null`, in the case of an unparseable string. + * Parses a column containing a JSON string into a `MapType` with `StringType` as keys type, + * `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3210,9 +3377,9 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string as a json string. In Spark 2.1, @@ -3227,9 +3394,9 @@ object functions { } /** - * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string as a json string, it could be a @@ -3242,11 +3409,53 @@ object functions { val dataType = try { DataType.fromJson(schema) } catch { - case NonFatal(_) => StructType.fromDDL(schema) + case NonFatal(_) => DataType.fromDDL(schema) } from_json(e, dataType, options) } + /** + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * + * @group collection_funcs + * @since 2.4.0 + */ + def from_json(e: Column, schema: Column): Column = { + from_json(e, schema, Map.empty[String, String].asJava) + } + + /** + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * @param options options to control how the json is parsed. accepts the same options and the + * json data source. + * + * @group collection_funcs + * @since 2.4.0 + */ + def from_json(e: Column, schema: Column, options: java.util.Map[String, String]): Column = { + withExpr(new JsonToStructs(e.expr, schema.expr, options.asScala.toMap)) + } + + /** + * Parses a column containing a JSON string and infers its schema. + * + * @param e a string column containing JSON data. + * + * @group collection_funcs + * @since 2.4.0 + */ + def schema_of_json(e: Column): Column = withExpr(new SchemaOfJson(e.expr)) + /** * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s, * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema. @@ -3302,6 +3511,7 @@ object functions { /** * Sorts the input array for the given column in ascending order, * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array. * * @group collection_funcs * @since 1.5.0 @@ -3311,6 +3521,8 @@ object functions { /** * Sorts the input array for the given column in ascending or descending order, * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array in ascending order or + * at the end of the returned array in descending order. * * @group collection_funcs * @since 1.5.0 @@ -3340,6 +3552,55 @@ object functions { */ def reverse(e: Column): Column = withExpr { Reverse(e.expr) } + /** + * Creates a single array from an array of arrays. If a structure of nested arrays is deeper than + * two levels, only one level of nesting is removed. + * @group collection_funcs + * @since 2.4.0 + */ + def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + + /** + * Generate a sequence of integers from start to stop, incrementing by step. + * + * @group collection_funcs + * @since 2.4.0 + */ + def sequence(start: Column, stop: Column, step: Column): Column = withExpr { + new Sequence(start.expr, stop.expr, step.expr) + } + + /** + * Generate a sequence of integers from start to stop, + * incrementing by 1 if start is less than or equal to stop, otherwise -1. + * + * @group collection_funcs + * @since 2.4.0 + */ + def sequence(start: Column, stop: Column): Column = withExpr { + new Sequence(start.expr, stop.expr) + } + + /** + * Creates an array containing the left argument repeated the number of times given by the + * right argument. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_repeat(left: Column, right: Column): Column = withExpr { + ArrayRepeat(left.expr, right.expr) + } + + /** + * Creates an array containing the left argument repeated the number of times given by the + * right argument. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_repeat(e: Column, count: Int): Column = array_repeat(e, lit(count)) + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs @@ -3354,6 +3615,156 @@ object functions { */ def map_values(e: Column): Column = withExpr { MapValues(e.expr) } + /** + * Returns an unordered array of all entries in the given map. + * @group collection_funcs + * @since 2.4.0 + */ + def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + + /** + * Returns a map created from the given array of entries. + * @group collection_funcs + * @since 2.4.0 + */ + def map_from_entries(e: Column): Column = withExpr { MapFromEntries(e.expr) } + + /** + * Returns a merged array of structs in which the N-th struct contains all N-th values of input + * arrays. + * @group collection_funcs + * @since 2.4.0 + */ + @scala.annotation.varargs + def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) } + + /** + * Returns the union of all the given maps. + * @group collection_funcs + * @since 2.4.0 + */ + @scala.annotation.varargs + def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) } + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Mask functions + ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Returns a string which is the masked representation of the input. + * @group mask_funcs + * @since 2.4.0 + */ + def mask(e: Column): Column = withExpr { new Mask(e.expr) } + + /** + * Returns a string which is the masked representation of the input, using `upper`, `lower` and + * `digit` as replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask(e: Column, upper: String, lower: String, digit: String): Column = withExpr { + Mask(e.expr, upper, lower, digit) + } + + /** + * Returns a string with the first `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_first_n(e: Column, n: Int): Column = withExpr { new MaskFirstN(e.expr, Literal(n)) } + + /** + * Returns a string with the first `n` characters masked, using `upper`, `lower` and `digit` as + * replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_first_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskFirstN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a string with the last `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_last_n(e: Column, n: Int): Column = withExpr { new MaskLastN(e.expr, Literal(n)) } + + /** + * Returns a string with the last `n` characters masked, using `upper`, `lower` and `digit` as + * replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_last_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskLastN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a string with all but the first `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_first_n(e: Column, n: Int): Column = withExpr { + new MaskShowFirstN(e.expr, Literal(n)) + } + + /** + * Returns a string with all but the first `n` characters masked, using `upper`, `lower` and + * `digit` as replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_first_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskShowFirstN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a string with all but the last `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_last_n(e: Column, n: Int): Column = withExpr { + new MaskShowLastN(e.expr, Literal(n)) + } + + /** + * Returns a string with all but the last `n` characters masked, using `upper`, `lower` and + * `digit` as replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_last_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskShowLastN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a hashed value based on the input column. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_hash(e: Column): Column = withExpr { MaskHash(e.expr) } + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 6ae307bce10c8..4698e8ab13ce3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -364,7 +364,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def dropTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef => - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession, viewDef, cascade = false, blocking = true) sessionCatalog.dropTempView(viewName) } } @@ -379,7 +380,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def dropGlobalTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef => - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession, viewDef, cascade = false, blocking = true) sessionCatalog.dropGlobalTempView(viewName) } } @@ -438,7 +440,9 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def uncacheTable(tableName: String): Unit = { - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + val cascade = !sessionCatalog.isTemporaryTable(tableIdent) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName), cascade) } /** @@ -490,7 +494,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { // cached version and make the new version cached lazily. if (isCached(table)) { // Uncache the logicalPlan. - sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery(table, cascade = true, blocking = true) // Cache it again. sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index baea4ceebf8e3..5b6160e2b408f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -99,7 +99,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A catalog that interacts with external systems. */ - lazy val externalCatalog: ExternalCatalog = { + lazy val externalCatalog: ExternalCatalogWithListener = { val externalCatalog = SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( SharedState.externalCatalogClassName(sparkContext.conf), sparkContext.conf, @@ -117,14 +117,17 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true) } + // Wrap to provide catalog events + val wrapped = new ExternalCatalogWithListener(externalCatalog) + // Make sure we propagate external catalog events to the spark listener bus - externalCatalog.addListener(new ExternalCatalogEventListener { + wrapped.addListener(new ExternalCatalogEventListener { override def onEvent(event: ExternalCatalogEvent): Unit = { sparkContext.listenerBus.post(event) } }) - externalCatalog + wrapped } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 2499e9b604f3e..bdd8c4da6bd30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -199,7 +199,7 @@ case class StringStartsWith(attribute: String, value: String) extends Filter { /** * A filter that evaluates to `true` iff the attribute evaluates to - * a string that starts with `value`. + * a string that ends with `value`. * * @since 1.3.1 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index ae93965bc50ed..ef8dc3a325a33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -270,6 +270,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * per file *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator * that should be used for parsing.
    • + *
    • `dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or + * empty array/struct during schema inference.
    • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index effc1471e8e12..3b9a56ffdde4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,14 +21,15 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} +import org.apache.spark.annotation.{InterfaceStability, Since} +import org.apache.spark.api.java.function.VoidFunction2 +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.execution.streaming.sources._ import org.apache.spark.sql.sources.v2.StreamWriteSupport /** @@ -269,7 +270,22 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) + val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc) + df.sparkSession.sessionState.streamingQueryManager.startQuery( + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), + df, + extraOptions.toMap, + sink, + outputMode, + useTempCheckpointLocation = true, + trigger = trigger) + } else if (source == "foreachBatch") { + assertNotPartitioned("foreachBatch") + if (trigger.isInstanceOf[ContinuousTrigger]) { + throw new AnalysisException("'foreachBatch' is not supported with continuous trigger") + } + val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -307,49 +323,9 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /** - * Starts the execution of the streaming query, which will continually send results to the given - * `ForeachWriter` as new data arrives. The `ForeachWriter` can be used to send the data - * generated by the `DataFrame`/`Dataset` to an external system. - * - * Scala example: - * {{{ - * datasetOfString.writeStream.foreach(new ForeachWriter[String] { - * - * def open(partitionId: Long, version: Long): Boolean = { - * // open connection - * } - * - * def process(record: String) = { - * // write string to connection - * } - * - * def close(errorOrNull: Throwable): Unit = { - * // close the connection - * } - * }).start() - * }}} - * - * Java example: - * {{{ - * datasetOfString.writeStream().foreach(new ForeachWriter() { - * - * @Override - * public boolean open(long partitionId, long version) { - * // open connection - * } - * - * @Override - * public void process(String value) { - * // write string to connection - * } - * - * @Override - * public void close(Throwable errorOrNull) { - * // close the connection - * } - * }).start(); - * }}} - * + * Sets the output of the streaming query to be processed using the provided writer object. + * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and + * semantics. * @since 2.0.0 */ def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { @@ -362,6 +338,45 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { this } + /** + * :: Experimental :: + * + * (Scala-specific) Sets the output of the streaming query to be processed using the provided + * function. This is supported only the in the micro-batch execution modes (that is, when the + * trigger is not continuous). In every micro-batch, the provided function will be called in + * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. + * The batchId can be used deduplicate and transactionally write the output + * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed + * to exactly same for the same batchId (assuming all operations are deterministic in the query). + * + * @since 2.4.0 + */ + @InterfaceStability.Evolving + def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = { + this.source = "foreachBatch" + if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null") + this.foreachBatchWriter = function + this + } + + /** + * :: Experimental :: + * + * (Java-specific) Sets the output of the streaming query to be processed using the provided + * function. This is supported only the in the micro-batch execution modes (that is, when the + * trigger is not continuous). In every micro-batch, the provided function will be called in + * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. + * The batchId can be used deduplicate and transactionally write the output + * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed + * to exactly same for the same batchId (assuming all operations are deterministic in the query). + * + * @since 2.4.0 + */ + @InterfaceStability.Evolving + def foreachBatch(function: VoidFunction2[Dataset[T], Long]): DataStreamWriter[T] = { + foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId)) + } + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => cols.map(normalize(_, "Partition")) } @@ -398,5 +413,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private var foreachWriter: ForeachWriter[T] = null + private var foreachBatchWriter: (Dataset[T], Long) => Unit = null + private var partitioningColumns: Option[Seq[String]] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 7cefd03e43bc3..25bb05212d66f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} @@ -32,6 +33,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS import org.apache.spark.sql.sources.v2.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -55,6 +57,19 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo @GuardedBy("awaitTerminationLock") private var lastTerminatedQuery: StreamingQuery = null + try { + sparkSession.sparkContext.conf.get(STREAMING_QUERY_LISTENERS).foreach { classNames => + Utils.loadExtensions(classOf[StreamingQueryListener], classNames, + sparkSession.sparkContext.conf).foreach(listener => { + addListener(listener) + logInfo(s"Registered listener ${listener.getClass.getName}") + }) + } + } catch { + case e: Exception => + throw new SparkException("Exception when registering StreamingQueryListener", e) + } + /** * Returns a list of active queries associated with this SQLContext * @@ -242,7 +257,9 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo (sink, trigger) match { case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => - UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { + UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + } new StreamingQueryWrapper(new ContinuousExecution( sparkSession, userSpecifiedName.orNull, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index c132cab1b38cf..2c695fc58fd8c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -34,6 +34,7 @@ import org.junit.*; import org.junit.rules.ExpectedException; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.sql.*; @@ -336,6 +337,23 @@ public void testTupleEncoder() { Assert.assertEquals(data5, ds5.collectAsList()); } + @Test + public void testTupleEncoderSchema() { + Encoder>> encoder = + Encoders.tuple(Encoders.STRING(), Encoders.tuple(Encoders.STRING(), Encoders.STRING())); + List>> data = Arrays.asList(tuple2("1", tuple2("a", "b")), + tuple2("2", tuple2("c", "d"))); + Dataset ds1 = spark.createDataset(data, encoder).toDF("value1", "value2"); + + JavaPairRDD> pairRDD = jsc.parallelizePairs(data); + Dataset ds2 = spark.createDataset(JavaPairRDD.toRDD(pairRDD), encoder) + .toDF("value1", "value2"); + + Assert.assertEquals(ds1.schema(), ds2.schema()); + Assert.assertEquals(ds1.select(expr("value2._1")).collectAsList(), + ds2.select(expr("value2._1")).collectAsList()); + } + @Test public void testNestedTupleEncoder() { // test ((int, string), string) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java new file mode 100644 index 0000000000000..a19ddbdbadba2 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.execution.sort; + +import org.apache.spark.SparkConf; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryConsumer; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.execution.RecordBinaryComparator; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.util.collection.unsafe.sort.*; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test the RecordBinaryComparator, which compares two UnsafeRows by their binary form. + */ +public class RecordBinaryComparatorSuite { + + private final TaskMemoryManager memoryManager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + private final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + + private final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + + private MemoryBlock dataPage; + private long pageCursor; + + private LongArray array; + private int pos; + + @Before + public void beforeEach() { + // Only compare between two input rows. + array = consumer.allocateArray(2); + pos = 0; + + dataPage = memoryManager.allocatePage(4096, consumer); + pageCursor = dataPage.getBaseOffset(); + } + + @After + public void afterEach() { + consumer.freePage(dataPage); + dataPage = null; + pageCursor = 0; + + consumer.freeArray(array); + array = null; + pos = 0; + } + + private void insertRow(UnsafeRow row) { + Object recordBase = row.getBaseObject(); + long recordOffset = row.getBaseOffset(); + int recordLength = row.getSizeInBytes(); + + Object baseObject = dataPage.getBaseObject(); + assert(pageCursor + recordLength <= dataPage.getBaseOffset() + dataPage.size()); + long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, pageCursor); + UnsafeAlignedOffset.putSize(baseObject, pageCursor, recordLength); + pageCursor += uaoSize; + Platform.copyMemory(recordBase, recordOffset, baseObject, pageCursor, recordLength); + pageCursor += recordLength; + + assert(pos < 2); + array.set(pos, recordAddress); + pos++; + } + + private int compare(int index1, int index2) { + Object baseObject = dataPage.getBaseObject(); + + long recordAddress1 = array.get(index1); + long baseOffset1 = memoryManager.getOffsetInPage(recordAddress1) + uaoSize; + int recordLength1 = UnsafeAlignedOffset.getSize(baseObject, baseOffset1 - uaoSize); + + long recordAddress2 = array.get(index2); + long baseOffset2 = memoryManager.getOffsetInPage(recordAddress2) + uaoSize; + int recordLength2 = UnsafeAlignedOffset.getSize(baseObject, baseOffset2 - uaoSize); + + return binaryComparator.compare(baseObject, baseOffset1, recordLength1, baseObject, + baseOffset2, recordLength2); + } + + private final RecordComparator binaryComparator = new RecordBinaryComparator(); + + // Compute the most compact size for UnsafeRow's backing data. + private int computeSizeInBytes(int originalSize) { + // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall + // always be 8. + return 8 + (originalSize + 7) / 8 * 8; + } + + // Compute the relative offset of variable-length values. + private long relativeOffset(int numFields) { + // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall + // always be 8. + return 8 + numFields * 8L; + } + + @Test + public void testBinaryComparatorForSingleColumnRow() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setInt(0, 11); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setInt(0, 42); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorForMultipleColumnRow() throws Exception { + int numFields = 5; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row1.setDouble(i, i * 3.14); + } + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row2.setDouble(i, 198.7 / (i + 1)); + } + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorForArrayColumn() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + UnsafeArrayData arrayData1 = UnsafeArrayData.fromPrimitiveArray(new int[]{11, 42, -1}); + row1.pointTo(data1, computeSizeInBytes(numFields * 8 + arrayData1.getSizeInBytes())); + row1.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData1.getSizeInBytes()); + Platform.copyMemory(arrayData1.getBaseObject(), arrayData1.getBaseOffset(), data1, + row1.getBaseOffset() + relativeOffset(numFields), arrayData1.getSizeInBytes()); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + UnsafeArrayData arrayData2 = UnsafeArrayData.fromPrimitiveArray(new int[]{22}); + row2.pointTo(data2, computeSizeInBytes(numFields * 8 + arrayData2.getSizeInBytes())); + row2.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData2.getSizeInBytes()); + Platform.copyMemory(arrayData2.getBaseObject(), arrayData2.getBaseOffset(), data2, + row2.getBaseOffset() + relativeOffset(numFields), arrayData2.getSizeInBytes()); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorForMixedColumns() throws Exception { + int numFields = 4; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + UTF8String str1 = UTF8String.fromString("Milk tea"); + row1.pointTo(data1, computeSizeInBytes(numFields * 8 + str1.numBytes())); + row1.setInt(0, 11); + row1.setDouble(1, 3.14); + row1.setInt(2, -1); + row1.setLong(3, (relativeOffset(numFields) << 32) | (long) str1.numBytes()); + Platform.copyMemory(str1.getBaseObject(), str1.getBaseOffset(), data1, + row1.getBaseOffset() + relativeOffset(numFields), str1.numBytes()); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + UTF8String str2 = UTF8String.fromString("Java"); + row2.pointTo(data2, computeSizeInBytes(numFields * 8 + str2.numBytes())); + row2.setInt(0, 11); + row2.setDouble(1, 3.14); + row2.setInt(2, -1); + row2.setLong(3, (relativeOffset(numFields) << 32) | (long) str2.numBytes()); + Platform.copyMemory(str2.getBaseObject(), str2.getBaseOffset(), data2, + row2.getBaseOffset() + relativeOffset(numFields), str2.numBytes()); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorForNullColumns() throws Exception { + int numFields = 3; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row1.setNullAt(i); + } + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields - 1; i++) { + row2.setNullAt(i); + } + row2.setDouble(numFields - 1, 3.14); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 172e5d5eebcbe..445cb29f5ee3a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -79,8 +79,8 @@ public Filter[] pushedFilters() { } @Override - public List> createDataReaderFactories() { - List> res = new ArrayList<>(); + public List> planInputPartitions() { + List> res = new ArrayList<>(); Integer lowerBound = null; for (Filter filter : filters) { @@ -94,33 +94,34 @@ public List> createDataReaderFactories() { } if (lowerBound == null) { - res.add(new JavaAdvancedDataReaderFactory(0, 5, requiredSchema)); - res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); + res.add(new JavaAdvancedInputPartition(0, 5, requiredSchema)); + res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); } else if (lowerBound < 4) { - res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 5, requiredSchema)); - res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); + res.add(new JavaAdvancedInputPartition(lowerBound + 1, 5, requiredSchema)); + res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); } else if (lowerBound < 9) { - res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 10, requiredSchema)); + res.add(new JavaAdvancedInputPartition(lowerBound + 1, 10, requiredSchema)); } return res; } } - static class JavaAdvancedDataReaderFactory implements DataReaderFactory, DataReader { + static class JavaAdvancedInputPartition implements InputPartition, + InputPartitionReader { private int start; private int end; private StructType requiredSchema; - JavaAdvancedDataReaderFactory(int start, int end, StructType requiredSchema) { + JavaAdvancedInputPartition(int start, int end, StructType requiredSchema) { this.start = start; this.end = end; this.requiredSchema = requiredSchema; } @Override - public DataReader createDataReader() { - return new JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema); + public InputPartitionReader createPartitionReader() { + return new JavaAdvancedInputPartition(start - 1, end, requiredSchema); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index c55093768105b..97d6176d02559 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -42,14 +42,14 @@ public StructType readSchema() { } @Override - public List> createBatchDataReaderFactories() { + public List> planBatchInputPartitions() { return java.util.Arrays.asList( - new JavaBatchDataReaderFactory(0, 50), new JavaBatchDataReaderFactory(50, 90)); + new JavaBatchInputPartition(0, 50), new JavaBatchInputPartition(50, 90)); } } - static class JavaBatchDataReaderFactory - implements DataReaderFactory, DataReader { + static class JavaBatchInputPartition + implements InputPartition, InputPartitionReader { private int start; private int end; @@ -59,13 +59,13 @@ static class JavaBatchDataReaderFactory private OnHeapColumnVector j; private ColumnarBatch batch; - JavaBatchDataReaderFactory(int start, int end) { + JavaBatchInputPartition(int start, int end) { this.start = start; this.end = end; } @Override - public DataReader createDataReader() { + public InputPartitionReader createPartitionReader() { this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); ColumnVector[] vectors = new ColumnVector[2]; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 32fad59b97ff6..e49c8cf8b9e16 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -43,10 +43,10 @@ public StructType readSchema() { } @Override - public List> createDataReaderFactories() { + public List> planInputPartitions() { return java.util.Arrays.asList( - new SpecificDataReaderFactory(new int[]{1, 1, 3}, new int[]{4, 4, 6}), - new SpecificDataReaderFactory(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}), + new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2})); } @Override @@ -73,12 +73,12 @@ public boolean satisfy(Distribution distribution) { } } - static class SpecificDataReaderFactory implements DataReaderFactory, DataReader { + static class SpecificInputPartition implements InputPartition, InputPartitionReader { private int[] i; private int[] j; private int current = -1; - SpecificDataReaderFactory(int[] i, int[] j) { + SpecificInputPartition(int[] i, int[] j) { assert i.length == j.length; this.i = i; this.j = j; @@ -101,7 +101,7 @@ public void close() throws IOException { } @Override - public DataReader createDataReader() { + public InputPartitionReader createPartitionReader() { return this; } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index 048d078dfaac4..80eeffd95f83b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -24,7 +24,7 @@ import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.types.StructType; public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { @@ -42,7 +42,7 @@ public StructType readSchema() { } @Override - public List> createDataReaderFactories() { + public List> planInputPartitions() { return java.util.Collections.emptyList(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 96f55b8a76811..8522a63898a3b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -25,8 +25,8 @@ import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.DataReader; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; +import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.types.StructType; @@ -41,25 +41,25 @@ public StructType readSchema() { } @Override - public List> createDataReaderFactories() { + public List> planInputPartitions() { return java.util.Arrays.asList( - new JavaSimpleDataReaderFactory(0, 5), - new JavaSimpleDataReaderFactory(5, 10)); + new JavaSimpleInputPartition(0, 5), + new JavaSimpleInputPartition(5, 10)); } } - static class JavaSimpleDataReaderFactory implements DataReaderFactory, DataReader { + static class JavaSimpleInputPartition implements InputPartition, InputPartitionReader { private int start; private int end; - JavaSimpleDataReaderFactory(int start, int end) { + JavaSimpleInputPartition(int start, int end) { this.start = start; this.end = end; } @Override - public DataReader createDataReader() { - return new JavaSimpleDataReaderFactory(start - 1, end); + public InputPartitionReader createPartitionReader() { + return new JavaSimpleInputPartition(start - 1, end); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index c3916e0b370b5..3ad8e7a0104ce 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -38,20 +38,20 @@ public StructType readSchema() { } @Override - public List> createUnsafeRowReaderFactories() { + public List> planUnsafeInputPartitions() { return java.util.Arrays.asList( - new JavaUnsafeRowDataReaderFactory(0, 5), - new JavaUnsafeRowDataReaderFactory(5, 10)); + new JavaUnsafeRowInputPartition(0, 5), + new JavaUnsafeRowInputPartition(5, 10)); } } - static class JavaUnsafeRowDataReaderFactory - implements DataReaderFactory, DataReader { + static class JavaUnsafeRowInputPartition + implements InputPartition, InputPartitionReader { private int start; private int end; private UnsafeRow row; - JavaUnsafeRowDataReaderFactory(int start, int end) { + JavaUnsafeRowInputPartition(int start, int end) { this.start = start; this.end = end; this.row = new UnsafeRow(2); @@ -59,8 +59,8 @@ static class JavaUnsafeRowDataReaderFactory } @Override - public DataReader createDataReader() { - return new JavaUnsafeRowDataReaderFactory(start - 1, end); + public InputPartitionReader createPartitionReader() { + return new JavaUnsafeRowInputPartition(start - 1, end); } @Override diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 547c2bef02b24..4950a4b7a4e5a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -27,3 +27,36 @@ select current_date = current_date(), current_timestamp = current_timestamp(), a select a, b from ttf2 order by a, current_date; select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15'); + +select from_utc_timestamp('2015-07-24 00:00:00', 'PST'); + +select from_utc_timestamp('2015-01-24 00:00:00', 'PST'); + +select from_utc_timestamp(null, 'PST'); + +select from_utc_timestamp('2015-07-24 00:00:00', null); + +select from_utc_timestamp(null, null); + +select from_utc_timestamp(cast(0 as timestamp), 'PST'); + +select from_utc_timestamp(cast('2015-01-24' as date), 'PST'); + +select to_utc_timestamp('2015-07-24 00:00:00', 'PST'); + +select to_utc_timestamp('2015-01-24 00:00:00', 'PST'); + +select to_utc_timestamp(null, 'PST'); + +select to_utc_timestamp('2015-07-24 00:00:00', null); + +select to_utc_timestamp(null, null); + +select to_utc_timestamp(cast(0 as timestamp), 'PST'); + +select to_utc_timestamp(cast('2015-01-24' as date), 'PST'); + +-- SPARK-23715: the input of to/from_utc_timestamp can not have timezone +select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); + +select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); diff --git a/sql/core/src/test/resources/sql-tests/inputs/extract.sql b/sql/core/src/test/resources/sql-tests/inputs/extract.sql new file mode 100644 index 0000000000000..9adf5d70056e2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/extract.sql @@ -0,0 +1,21 @@ +CREATE TEMPORARY VIEW t AS select '2011-05-06 07:08:09.1234567' as c; + +select extract(year from c) from t; + +select extract(quarter from c) from t; + +select extract(month from c) from t; + +select extract(week from c) from t; + +select extract(day from c) from t; + +select extract(dayofweek from c) from t; + +select extract(hour from c) from t; + +select extract(minute from c) from t; + +select extract(second from c) from t; + +select extract(not_supported from c) from t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index c5070b734d521..2c18d6aaabdba 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -68,4 +68,8 @@ SELECT 1 from ( FROM (select 1 as x) a WHERE false ) b -where b.z != b.z +where b.z != b.z; + +-- SPARK-24369 multiple distinct aggregations having the same argument set +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y); diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql index 8afa3270f4de4..2e6a5f362a8fa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a); CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a); diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index fea069eac4d48..79fdd5895e691 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -31,3 +31,11 @@ CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, SELECT json_tuple(jsonField, 'b', CAST(NULL AS STRING), a) FROM jsonTable; -- Clean up DROP VIEW IF EXISTS jsonTable; + +-- from_json - complex types +select from_json('{"a":1, "b":2}', 'map'); +select from_json('{"a":1, "b":"2"}', 'struct'); + +-- infer schema of json literal +select schema_of_json('{"c1":0, "c2":[1]}'); +select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')); diff --git a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql index 71a50157b766c..e0abeda3eb44f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + create temporary view nt1 as select * from values ("one", 1), ("two", 2), diff --git a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql index cdc6c81e10047..ce09c21568f13 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + -- SPARK-17099: Incorrect result when HAVING clause is added to group by query CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (-234), (145), (367), (975), (298) diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql new file mode 100644 index 0000000000000..b3d53adfbebe7 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -0,0 +1,131 @@ +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings); + +create temporary view years as select * from values + (2012, 1), + (2013, 2) + as years(y, s); + +-- pivot courses +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot years with no subquery +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +); + +-- pivot courses with multiple aggregations +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), avg(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with no group by column +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with no group by column and with multiple aggregations on different columns +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), min(year) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on join query with multiple group by columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN (1, 2) +); + +-- pivot on join query with multiple aggregations on different columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings), min(s) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on join query with multiple columns in one aggregation +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings * s) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with aliases and projection +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012, 2013) +); + +-- pivot years with non-aggregate function +SELECT * FROM courseSales +PIVOT ( + abs(earnings) + FOR year IN (2012, 2013) +); + +-- pivot with unresolvable columns +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +); + +-- pivot with complex aggregate expressions +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + ceil(sum(earnings)), avg(earnings) + 1 as a1 + FOR course IN ('dotNET', 'Java') +); + +-- pivot with invalid arguments in aggregate expressions +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(avg(earnings)) + FOR course IN ('dotNET', 'Java') +); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql index cc4ed64affec7..cefc3fe6272ab 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql @@ -1,5 +1,9 @@ -- Tests EXISTS subquery support. Tests Exists subquery -- used in Joins (Both when joins occurs in outer and suquery blocks) +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false CREATE TEMPORARY VIEW EMP AS SELECT * FROM VALUES (100, "emp 1", date "2005-01-01", 100.00D, 10), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql index 880175fd7add0..22f3eafd6a02d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql @@ -1,5 +1,9 @@ -- A test suite for IN JOINS in parent side, subquery, and both predicate subquery -- It includes correlated cases. +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql index e09b91f18de0a..4f8ca8bfb27c1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql @@ -1,5 +1,9 @@ -- A test suite for not-in-joins in parent side, subquery, and both predicate subquery -- It includes correlated cases. +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql new file mode 100644 index 0000000000000..8eea84f4f5272 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql @@ -0,0 +1,39 @@ +-- Unit tests for simple NOT IN predicate subquery across multiple columns. +-- +-- See not-in-single-column-unit-tests.sql for an introduction. +-- This file has the same test cases as not-in-unit-tests-multi-column.sql with literals instead of +-- subqueries. Small changes have been made to the literals to make them typecheck. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + +-- Case 1 (not possible to write a literal with no rows, so we ignore it.) +-- (subquery is empty -> row is returned) + +-- Cases 2, 3 and 4 are currently broken, so I have commented them out here. +-- Filed https://issues.apache.org/jira/browse/SPARK-24395 to fix and restore these test cases. + + -- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((2, 3.0)); + + -- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN ((2, 3.0)); + + -- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN ((2, 3.0)); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql new file mode 100644 index 0000000000000..9f8dc7fca3b94 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql @@ -0,0 +1,98 @@ +-- Unit tests for simple NOT IN predicate subquery across multiple columns. +-- +-- See not-in-single-column-unit-tests.sql for an introduction. +-- +-- Test cases for multi-column ``WHERE a NOT IN (SELECT c FROM r ...)'': +-- | # | does subquery include null? | do filter columns contain null? | a = c? | b = d? | row included in result? | +-- | 1 | empty | * | * | * | yes | +-- | 2 | 1+ row has null for all columns | * | * | * | no | +-- | 3 | no row has null for all columns | (yes, yes) | * | * | no | +-- | 4 | no row has null for all columns | (no, yes) | yes | * | no | +-- | 5 | no row has null for all columns | (no, yes) | no | * | yes | +-- | 6 | no | (no, no) | yes | yes | no | +-- | 7 | no | (no, no) | _ | _ | yes | +-- +-- This can be generalized to include more tests for more columns, but it covers the main cases +-- when there is more than one column. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, null), + (0, 1.0), + (2, 3.0), + (4, null) + AS s(c, d); + + -- Case 1 + -- (subquery is empty -> row is returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE d > 5.0) -- Matches no rows +; + + -- Case 2 + -- (subquery contains a row with null in all columns -> row not returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NULL AND d IS NULL) -- Matches only (null, null) +; + + -- Case 3 + -- (probe-side columns are all null -> row not returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +; + + -- Case 4 + -- (one column null, other column matches a row in the subquery result -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +; + + -- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +; + + -- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +; + + -- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql new file mode 100644 index 0000000000000..b261363d1dde7 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql @@ -0,0 +1,42 @@ +-- Unit tests for simple NOT IN with a literal expression of a single column +-- +-- More information can be found in not-in-unit-tests-single-column.sql. +-- This file has the same test cases as not-in-unit-tests-single-column.sql with literals instead of +-- subqueries. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + + -- Uncorrelated NOT IN Subquery test cases + -- Case 1 (not possible to write a literal with no rows, so we ignore it.) + -- (empty subquery -> all rows returned) + + -- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (null); + + -- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (2); + + -- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (2); + + -- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (6); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql new file mode 100644 index 0000000000000..2cc08e10acf67 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql @@ -0,0 +1,123 @@ +-- Unit tests for simple NOT IN predicate subquery across a single column. +-- +-- ``col NOT IN expr'' is quite difficult to reason about. There are many edge cases, some of the +-- rules are confusing to the uninitiated, and precedence and treatment of null values is plain +-- unintuitive. To make this simpler to understand, I've come up with a plain English way of +-- describing the expected behavior of this query. +-- +-- - If the subquery is empty (i.e. returns no rows), the row should be returned, regardless of +-- whether the filtered columns include nulls. +-- - If the subquery contains a result with all columns null, then the row should not be returned. +-- - If for all non-null filter columns there exists a row in the subquery in which each column +-- either +-- 1. is equal to the corresponding filter column or +-- 2. is null +-- then the row should not be returned. (This includes the case where all filter columns are +-- null.) +-- - Otherwise, the row should be returned. +-- +-- Using these rules, we can come up with a set of test cases for single-column and multi-column +-- NOT IN test cases. +-- +-- Test cases for single-column ``WHERE a NOT IN (SELECT c FROM r ...)'': +-- | # | does subquery include null? | is a null? | a = c? | row with a included in result? | +-- | 1 | empty | | | yes | +-- | 2 | yes | | | no | +-- | 3 | no | yes | | no | +-- | 4 | no | no | yes | no | +-- | 5 | no | no | no | yes | +-- +-- There are also some considerations around correlated subqueries. Correlated subqueries can +-- cause cases 2, 3, or 4 to be reduced to case 1 by limiting the number of rows returned by the +-- subquery, so the row from the parent table should always be included in the output. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (6, 7.0) + AS s(c, d); + + -- Uncorrelated NOT IN Subquery test cases + -- Case 1 + -- (empty subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d > 10.0) -- (empty subquery) +; + + -- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = 1.0) -- Only matches (null, 1.0) +; + + -- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +; + + -- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +; + + -- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 7.0) -- Matches (6, 7.0) +; + + -- Correlated NOT IN subquery test cases + -- Case 2->1 + -- (subquery had nulls but they are removed by correlated subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +; + + -- Case 3->1 + -- (probe column is null but subquery returns no rows -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +; + + -- Case 4->1 + -- (probe column matches row which is filtered out by correlated subquery -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql new file mode 100644 index 0000000000000..99729c007b104 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql @@ -0,0 +1,11 @@ +SELECT array_join(array(true, false), ', '); +SELECT array_join(array(2Y, 1Y), ', '); +SELECT array_join(array(2S, 1S), ', '); +SELECT array_join(array(2, 1), ', '); +SELECT array_join(array(2L, 1L), ', '); +SELECT array_join(array(9223372036854775809, 9223372036854775808), ', '); +SELECT array_join(array(2.0D, 1.0D), ', '); +SELECT array_join(array(float(2.0), float(1.0)), ', '); +SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', '); +SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', '); +SELECT array_join(array('a', 'b'), ', '); diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql index 9be7fcdadfea8..28a0e20c0f495 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -40,12 +40,14 @@ select 10.3000 * 3.0; select 10.30000 * 30.0; select 10.300000000000000000 * 3.000000000000000000; select 10.300000000000000000 * 3.0000000000000000000; +select 2.35E10 * 1.0; -- arithmetic operations causing an overflow return NULL select (5e36 + 0.1) + 5e36; select (-4e36 - 0.1) - 7e36; select 12345678901234567890.0 * 12345678901234567890.0; select 1e35 / 0.1; +select 1.2345678901234567890E30 * 1.2345678901234567890E25; -- arithmetic operations causing a precision loss are truncated select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; @@ -67,12 +69,14 @@ select 10.3000 * 3.0; select 10.30000 * 30.0; select 10.300000000000000000 * 3.000000000000000000; select 10.300000000000000000 * 3.0000000000000000000; +select 2.35E10 * 1.0; -- arithmetic operations causing an overflow return NULL select (5e36 + 0.1) + 5e36; select (-4e36 - 0.1) - 7e36; select 12345678901234567890.0 * 12345678901234567890.0; select 1e35 / 0.1; +select 1.2345678901234567890E30 * 1.2345678901234567890E25; -- arithmetic operations causing a precision loss return NULL select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql new file mode 100644 index 0000000000000..fc26397b881b5 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql @@ -0,0 +1,94 @@ +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), map(false, true), + map(1Y, 2Y), map(3Y, 4Y), + map(1S, 2S), map(3S, 4S), + map(4, 6), map(7, 8), + map(6L, 7L), map(8L, 9L), + map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809), + map(1.0D, 2.0D), map(3.0D, 4.0D), + map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)), + map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'), + map('a', 'b'), map('c', 'd'), + map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), + map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), + map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), + map('a', 1), map('c', 2), + map(1, 'a'), map(2, 'c') +) AS various_maps ( + boolean_map1, boolean_map2, + tinyint_map1, tinyint_map2, + smallint_map1, smallint_map2, + int_map1, int_map2, + bigint_map1, bigint_map2, + decimal_map1, decimal_map2, + double_map1, double_map2, + float_map1, float_map2, + date_map1, date_map2, + timestamp_map1, + timestamp_map2, + string_map1, string_map2, + array_map1, array_map2, + struct_map1, struct_map2, + map_map1, map_map2, + string_int_map1, string_int_map2, + int_string_map1, int_string_map2 +); + +-- Concatenate maps of the same type +SELECT + map_concat(boolean_map1, boolean_map2) boolean_map, + map_concat(tinyint_map1, tinyint_map2) tinyint_map, + map_concat(smallint_map1, smallint_map2) smallint_map, + map_concat(int_map1, int_map2) int_map, + map_concat(bigint_map1, bigint_map2) bigint_map, + map_concat(decimal_map1, decimal_map2) decimal_map, + map_concat(float_map1, float_map2) float_map, + map_concat(double_map1, double_map2) double_map, + map_concat(date_map1, date_map2) date_map, + map_concat(timestamp_map1, timestamp_map2) timestamp_map, + map_concat(string_map1, string_map2) string_map, + map_concat(array_map1, array_map2) array_map, + map_concat(struct_map1, struct_map2) struct_map, + map_concat(map_map1, map_map2) map_map, + map_concat(string_int_map1, string_int_map2) string_int_map, + map_concat(int_string_map1, int_string_map2) int_string_map +FROM various_maps; + +-- Concatenate maps of different types +SELECT + map_concat(tinyint_map1, smallint_map2) ts_map, + map_concat(smallint_map1, int_map2) si_map, + map_concat(int_map1, bigint_map2) ib_map, + map_concat(decimal_map1, float_map2) df_map, + map_concat(string_map1, date_map2) std_map, + map_concat(timestamp_map1, string_map2) tst_map, + map_concat(string_map1, int_map2) sti_map, + map_concat(int_string_map1, tinyint_map2) istt_map +FROM various_maps; + +-- Concatenate map of incompatible types 1 +SELECT + map_concat(tinyint_map1, map_map2) tm_map +FROM various_maps; + +-- Concatenate map of incompatible types 2 +SELECT + map_concat(boolean_map1, int_map2) bi_map +FROM various_maps; + +-- Concatenate map of incompatible types 3 +SELECT + map_concat(int_map1, struct_map2) is_map +FROM various_maps; + +-- Concatenate map of incompatible types 4 +SELECT + map_concat(map_map1, array_map2) ma_map +FROM various_maps; + +-- Concatenate map of incompatible types 5 +SELECT + map_concat(map_map1, struct_map2) ms_map +FROM various_maps; diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql new file mode 100644 index 0000000000000..92c7e26e3add2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql @@ -0,0 +1,56 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (101, 1, 1, 1), + (201, 2, 1, 1), + (301, 3, 1, 1), + (401, 4, 1, 11), + (501, 5, 1, null), + (601, 6, null, 1), + (701, 6, null, null), + (102, 1, 2, 2), + (202, 2, 1, 2), + (302, 3, 2, 1), + (402, 4, 2, 12), + (502, 5, 2, null), + (602, 6, null, 2), + (702, 6, null, null), + (103, 1, 3, 3), + (203, 2, 1, 3), + (303, 3, 3, 1), + (403, 4, 3, 13), + (503, 5, 3, null), + (603, 6, null, 3), + (703, 6, null, null), + (104, 1, 4, 4), + (204, 2, 1, 4), + (304, 3, 4, 1), + (404, 4, 4, 14), + (504, 5, 4, null), + (604, 6, null, 4), + (704, 6, null, null), + (800, 7, 1, 1) +as t1(id, px, y, x); + +select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), + regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), + regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) +from t1 group by px order by px; + + +select id, regr_count(y,x) over (partition by px) from t1 order by id; diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql index e57d69eaad033..6da1b9b49b226 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/union.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql @@ -35,6 +35,17 @@ FROM (SELECT col AS col SELECT col FROM p3) T1) T2; +-- SPARK-24012 Union of map and other compatible columns. +SELECT map(1, 2), 'str' +UNION ALL +SELECT map(1, 2, 3, NULL), 1; + +-- SPARK-24012 Union of array and other compatible columns. +SELECT array(1, 2), 'str' +UNION ALL +SELECT array(1, 2, 3, NULL), 1; + + -- Clean-up DROP VIEW IF EXISTS t1; DROP VIEW IF EXISTS t2; diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 4e1cfa6e48c1c..9eede305dbdcc 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 26 -- !query 0 @@ -82,9 +82,138 @@ struct 1 2 2 3 + -- !query 9 select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15') --- !query 3 schema +-- !query 9 schema struct --- !query 3 output +-- !query 9 output 5 3 5 NULL 4 + + +-- !query 10 +select from_utc_timestamp('2015-07-24 00:00:00', 'PST') +-- !query 10 schema +struct +-- !query 10 output +2015-07-23 17:00:00 + + +-- !query 11 +select from_utc_timestamp('2015-01-24 00:00:00', 'PST') +-- !query 11 schema +struct +-- !query 11 output +2015-01-23 16:00:00 + + +-- !query 12 +select from_utc_timestamp(null, 'PST') +-- !query 12 schema +struct +-- !query 12 output +NULL + + +-- !query 13 +select from_utc_timestamp('2015-07-24 00:00:00', null) +-- !query 13 schema +struct +-- !query 13 output +NULL + + +-- !query 14 +select from_utc_timestamp(null, null) +-- !query 14 schema +struct +-- !query 14 output +NULL + + +-- !query 15 +select from_utc_timestamp(cast(0 as timestamp), 'PST') +-- !query 15 schema +struct +-- !query 15 output +1969-12-31 08:00:00 + + +-- !query 16 +select from_utc_timestamp(cast('2015-01-24' as date), 'PST') +-- !query 16 schema +struct +-- !query 16 output +2015-01-23 16:00:00 + + +-- !query 17 +select to_utc_timestamp('2015-07-24 00:00:00', 'PST') +-- !query 17 schema +struct +-- !query 17 output +2015-07-24 07:00:00 + + +-- !query 18 +select to_utc_timestamp('2015-01-24 00:00:00', 'PST') +-- !query 18 schema +struct +-- !query 18 output +2015-01-24 08:00:00 + + +-- !query 19 +select to_utc_timestamp(null, 'PST') +-- !query 19 schema +struct +-- !query 19 output +NULL + + +-- !query 20 +select to_utc_timestamp('2015-07-24 00:00:00', null) +-- !query 20 schema +struct +-- !query 20 output +NULL + + +-- !query 21 +select to_utc_timestamp(null, null) +-- !query 21 schema +struct +-- !query 21 output +NULL + + +-- !query 22 +select to_utc_timestamp(cast(0 as timestamp), 'PST') +-- !query 22 schema +struct +-- !query 22 output +1970-01-01 00:00:00 + + +-- !query 23 +select to_utc_timestamp(cast('2015-01-24' as date), 'PST') +-- !query 23 schema +struct +-- !query 23 output +2015-01-24 08:00:00 + + +-- !query 24 +select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') +-- !query 24 schema +struct +-- !query 24 output +NULL + + +-- !query 25 +select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') +-- !query 25 schema +struct +-- !query 25 output +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out index 51dac111029e8..8ba69c698b551 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out @@ -57,6 +57,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Created Time [not included in comparison] +Last Access [not included in comparison] # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -89,7 +91,9 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 -Partition Statistics 1067 bytes, 3 rows +Created Time [not included in comparison] +Last Access [not included in comparison] +Partition Statistics 1121 bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -122,7 +126,9 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 -Partition Statistics 1067 bytes, 3 rows +Created Time [not included in comparison] +Last Access [not included in comparison] +Partition Statistics 1121 bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -147,7 +153,9 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 -Partition Statistics 1080 bytes, 4 rows +Created Time [not included in comparison] +Last Access [not included in comparison] +Partition Statistics 1098 bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -180,7 +188,9 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 -Partition Statistics 1067 bytes, 3 rows +Created Time [not included in comparison] +Last Access [not included in comparison] +Partition Statistics 1121 bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -205,7 +215,9 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 -Partition Statistics 1080 bytes, 4 rows +Created Time [not included in comparison] +Last Access [not included in comparison] +Partition Statistics 1098 bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -230,7 +242,9 @@ Database default Table t Partition Values [ds=2017-09-01, hr=5] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5 -Partition Statistics 1054 bytes, 2 rows +Created Time [not included in comparison] +Last Access [not included in comparison] +Partition Statistics 1144 bytes, 2 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 8c908b7625056..79390cb424444 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -282,6 +282,8 @@ Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 Storage Properties [a=1, b=2] +Created Time [not included in comparison] +Last Access [not included in comparison] # Storage Information Num Buckets 2 @@ -311,6 +313,8 @@ Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 Storage Properties [a=1, b=2] +Created Time [not included in comparison] +Last Access [not included in comparison] # Storage Information Num Buckets 2 diff --git a/sql/core/src/test/resources/sql-tests/results/extract.sql.out b/sql/core/src/test/resources/sql-tests/results/extract.sql.out new file mode 100644 index 0000000000000..160e4c7d78455 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/extract.sql.out @@ -0,0 +1,96 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 11 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS select '2011-05-06 07:08:09.1234567' as c +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select extract(year from c) from t +-- !query 1 schema +struct +-- !query 1 output +2011 + + +-- !query 2 +select extract(quarter from c) from t +-- !query 2 schema +struct +-- !query 2 output +2 + + +-- !query 3 +select extract(month from c) from t +-- !query 3 schema +struct +-- !query 3 output +5 + + +-- !query 4 +select extract(week from c) from t +-- !query 4 schema +struct +-- !query 4 output +18 + + +-- !query 5 +select extract(day from c) from t +-- !query 5 schema +struct +-- !query 5 output +6 + + +-- !query 6 +select extract(dayofweek from c) from t +-- !query 6 schema +struct +-- !query 6 output +6 + + +-- !query 7 +select extract(hour from c) from t +-- !query 7 schema +struct +-- !query 7 output +7 + + +-- !query 8 +select extract(minute from c) from t +-- !query 8 schema +struct +-- !query 8 output +8 + + +-- !query 9 +select extract(second from c) from t +-- !query 9 schema +struct +-- !query 9 output +9 + + +-- !query 10 +select extract(not_supported from c) from t +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.catalyst.parser.ParseException + +Literals of type 'NOT_SUPPORTED' are currently not supported.(line 1, pos 7) + +== SQL == +select extract(not_supported from c) from t +-------^^^ diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index c1abc6dff754b..581aa1754ce14 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 27 -- !query 0 @@ -241,3 +241,12 @@ where b.z != b.z struct<1:int> -- !query 25 output + + +-- !query 26 +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) +-- !query 26 schema +struct +-- !query 26 output +1.0 1.0 3 diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 581dddc89d0bb..3d49323751a10 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 30 -- !query 0 @@ -129,7 +129,7 @@ select to_json() struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function to_json. Expected: one of 1, 2 and 3; Found: 0; line 1 pos 7 +Invalid number of arguments for function to_json. Expected: one of 1 and 2; Found: 0; line 1 pos 7 -- !query 13 @@ -183,7 +183,7 @@ select from_json('{"a":1}', 1) struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -Expected a string literal instead of 1;; line 1 pos 7 +Schema should be specified in DDL format as a string literal or output of the schema_of_json function instead of 1;; line 1 pos 7 -- !query 18 @@ -225,7 +225,7 @@ select from_json() struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function from_json. Expected: one of 2, 3 and 4; Found: 0; line 1 pos 7 +Invalid number of arguments for function from_json. Expected: one of 2 and 3; Found: 0; line 1 pos 7 -- !query 22 @@ -258,3 +258,35 @@ DROP VIEW IF EXISTS jsonTable struct<> -- !query 25 output + + +-- !query 26 +select from_json('{"a":1, "b":2}', 'map') +-- !query 26 schema +struct> +-- !query 26 output +{"a":1,"b":2} + + +-- !query 27 +select from_json('{"a":1, "b":"2"}', 'struct') +-- !query 27 schema +struct> +-- !query 27 output +{"a":1,"b":"2"} + + +-- !query 28 +select schema_of_json('{"c1":0, "c2":[1]}') +-- !query 28 schema +struct +-- !query 28 output +struct> + + +-- !query 29 +select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')) +-- !query 29 schema +struct>> +-- !query 29 output +{"c1":[1,2,3]} diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out new file mode 100644 index 0000000000000..922d8b9f9152c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -0,0 +1,224 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 15 + + +-- !query 0 +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view years as select * from values + (2012, 1), + (2013, 2) + as years(y, s) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 2 schema +struct +-- !query 2 output +2012 15000 20000 +2013 48000 30000 + + +-- !query 3 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +) +-- !query 3 schema +struct +-- !query 3 output +Java 20000 30000 +dotNET 15000 48000 + + +-- !query 4 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), avg(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 4 schema +struct +-- !query 4 output +2012 15000 7500.0 20000 20000.0 +2013 48000 48000.0 30000 30000.0 + + +-- !query 5 +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 5 schema +struct +-- !query 5 output +63000 50000 + + +-- !query 6 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), min(year) + FOR course IN ('dotNET', 'Java') +) +-- !query 6 schema +struct +-- !query 6 output +63000 2012 50000 2012 + + +-- !query 7 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN (1, 2) +) +-- !query 7 schema +struct +-- !query 7 output +Java 2012 20000 NULL +Java 2013 NULL 30000 +dotNET 2012 15000 NULL +dotNET 2013 NULL 48000 + + +-- !query 8 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings), min(s) + FOR course IN ('dotNET', 'Java') +) +-- !query 8 schema +struct +-- !query 8 output +2012 15000 1 20000 1 +2013 48000 2 30000 2 + + +-- !query 9 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings * s) + FOR course IN ('dotNET', 'Java') +) +-- !query 9 schema +struct +-- !query 9 output +2012 15000 20000 +2013 96000 60000 + + +-- !query 10 +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012, 2013) +) +-- !query 10 schema +struct<2012_s:bigint,2013_s:bigint,2012_a:double,2013_a:double,c:string> +-- !query 10 output +15000 48000 7500.0 48000.0 dotNET +20000 30000 20000.0 30000.0 Java + + +-- !query 11 +SELECT * FROM courseSales +PIVOT ( + abs(earnings) + FOR year IN (2012, 2013) +) +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +Aggregate expression required for pivot, but 'coursesales.`earnings`' did not appear in any aggregate function.; + + +-- !query 12 +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +) +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 + + +-- !query 13 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + ceil(sum(earnings)), avg(earnings) + 1 as a1 + FOR course IN ('dotNET', 'Java') +) +-- !query 13 schema +struct +-- !query 13 output +2012 15000 7501.0 20000 20001.0 +2013 48000 48001.0 30000 30001.0 + + +-- !query 14 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(avg(earnings)) + FOR course IN ('dotNET', 'Java') +) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.; diff --git a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out index 975bb06124744..abeb7e18f031e 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out @@ -178,6 +178,8 @@ struct -- !query 14 output showdb show_t1 false Partition Values: [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t1/c=Us/d=1 +Created Time [not included in comparison] +Last Access [not included in comparison] -- !query 15 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out new file mode 100644 index 0000000000000..a16e98af9a417 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out @@ -0,0 +1,54 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 4 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +-- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((2, 3.0)) +-- !query 1 schema +struct +-- !query 1 output +NULL 1 + + +-- !query 2 +-- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN ((2, 3.0)) +-- !query 2 schema +struct +-- !query 2 output + + + +-- !query 3 +-- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN ((2, 3.0)) +-- !query 3 schema +struct +-- !query 3 output +4 5 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out new file mode 100644 index 0000000000000..aa5f64b8ebf55 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out @@ -0,0 +1,134 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 9 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, null), + (0, 1.0), + (2, 3.0), + (4, null) + AS s(c, d) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +-- Case 1 + -- (subquery is empty -> row is returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE d > 5.0) -- Matches no rows +-- !query 2 schema +struct +-- !query 2 output +2 3 +4 5 +NULL 1 +NULL NULL + + +-- !query 3 +-- Case 2 + -- (subquery contains a row with null in all columns -> row not returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NULL AND d IS NULL) -- Matches only (null, null) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 3 + -- (probe-side columns are all null -> row not returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +-- Case 4 + -- (one column null, other column matches a row in the subquery result -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +-- !query 5 schema +struct +-- !query 5 output + + + +-- !query 6 +-- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +-- !query 6 schema +struct +-- !query 6 output +NULL 1 + + +-- !query 7 +-- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +-- !query 7 schema +struct +-- !query 7 output + + + +-- !query 8 +-- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +-- !query 8 schema +struct +-- !query 8 output +4 5 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out new file mode 100644 index 0000000000000..446447e890449 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out @@ -0,0 +1,69 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 5 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +-- Uncorrelated NOT IN Subquery test cases + -- Case 1 (not possible to write a literal with no rows, so we ignore it.) + -- (empty subquery -> all rows returned) + + -- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (null) +-- !query 1 schema +struct +-- !query 1 output + + + +-- !query 2 +-- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (2) +-- !query 2 schema +struct +-- !query 2 output + + + +-- !query 3 +-- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (2) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (6) +-- !query 4 schema +struct +-- !query 4 output +2 3 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out new file mode 100644 index 0000000000000..f58ebeacc2872 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out @@ -0,0 +1,149 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (6, 7.0) + AS s(c, d) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +-- Uncorrelated NOT IN Subquery test cases + -- Case 1 + -- (empty subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d > 10.0) -- (empty subquery) +-- !query 2 schema +struct +-- !query 2 output +2 3 +4 5 +NULL 1 + + +-- !query 3 +-- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = 1.0) -- Only matches (null, 1.0) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +-- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +-- !query 5 schema +struct +-- !query 5 output + + + +-- !query 6 +-- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 7.0) -- Matches (6, 7.0) +-- !query 6 schema +struct +-- !query 6 output +2 3 + + +-- !query 7 +-- Correlated NOT IN subquery test cases + -- Case 2->1 + -- (subquery had nulls but they are removed by correlated subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +-- !query 7 schema +struct +-- !query 7 output +2 3 +4 5 +NULL 1 + + +-- !query 8 +-- Case 3->1 + -- (probe column is null but subquery returns no rows -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +-- !query 8 schema +struct +-- !query 8 output +NULL 1 + + +-- !query 9 +-- Case 4->1 + -- (probe column matches row which is filtered out by correlated subquery -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +-- !query 9 schema +struct +-- !query 9 output +2 3 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out new file mode 100644 index 0000000000000..b23a62dacef7c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out @@ -0,0 +1,90 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 11 + + +-- !query 0 +SELECT array_join(array(true, false), ', ') +-- !query 0 schema +struct +-- !query 0 output +true, false + + +-- !query 1 +SELECT array_join(array(2Y, 1Y), ', ') +-- !query 1 schema +struct +-- !query 1 output +2, 1 + + +-- !query 2 +SELECT array_join(array(2S, 1S), ', ') +-- !query 2 schema +struct +-- !query 2 output +2, 1 + + +-- !query 3 +SELECT array_join(array(2, 1), ', ') +-- !query 3 schema +struct +-- !query 3 output +2, 1 + + +-- !query 4 +SELECT array_join(array(2L, 1L), ', ') +-- !query 4 schema +struct +-- !query 4 output +2, 1 + + +-- !query 5 +SELECT array_join(array(9223372036854775809, 9223372036854775808), ', ') +-- !query 5 schema +struct +-- !query 5 output +9223372036854775809, 9223372036854775808 + + +-- !query 6 +SELECT array_join(array(2.0D, 1.0D), ', ') +-- !query 6 schema +struct +-- !query 6 output +2.0, 1.0 + + +-- !query 7 +SELECT array_join(array(float(2.0), float(1.0)), ', ') +-- !query 7 schema +struct +-- !query 7 output +2.0, 1.0 + + +-- !query 8 +SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', ') +-- !query 8 schema +struct +-- !query 8 output +2016-03-14, 2016-03-13 + + +-- !query 9 +SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', ') +-- !query 9 schema +struct +-- !query 9 output +2016-11-15 20:54:00, 2016-11-12 20:54:00 + + +-- !query 10 +SELECT array_join(array('a', 'b'), ', ') +-- !query 10 schema +struct +-- !query 10 output +a, b diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index 62befc5ca0f15..be637b66abc86 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 11 +-- Number of queries: 14 -- !query 0 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out index 6bfdb84548d4d..cbf44548b3cce 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 36 +-- Number of queries: 40 -- !query 0 @@ -114,190 +114,222 @@ struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.00000000000000000 -- !query 13 -select (5e36 + 0.1) + 5e36 +select 2.35E10 * 1.0 -- !query 13 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(6,-7)> -- !query 13 output -NULL +23500000000 -- !query 14 -select (-4e36 - 0.1) - 7e36 +select (5e36 + 0.1) + 5e36 -- !query 14 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 14 output NULL -- !query 15 -select 12345678901234567890.0 * 12345678901234567890.0 +select (-4e36 - 0.1) - 7e36 -- !query 15 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 15 output NULL -- !query 16 -select 1e35 / 0.1 +select 12345678901234567890.0 * 12345678901234567890.0 -- !query 16 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> -- !query 16 output NULL -- !query 17 -select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +select 1e35 / 0.1 -- !query 17 schema -struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)> +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> -- !query 17 output -10012345678912345678912345678911.246907 +NULL -- !query 18 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 1.2345678901234567890E30 * 1.2345678901234567890E25 -- !query 18 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(25,-6)) * CAST(1.2345678901234567890E+25 AS DECIMAL(25,-6))):decimal(38,-17)> -- !query 18 output -138698367904130467.654320988515622621 +NULL -- !query 19 -select 12345678912345.123456789123 / 0.000000012345678 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 -- !query 19 schema -struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)> +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)> -- !query 19 output -1000000073899961059796.725866332 +10012345678912345678912345678911.246907 -- !query 20 -set spark.sql.decimalOperations.allowPrecisionLoss=false +select 123456789123456789.1234567890 * 1.123456789123456789 -- !query 20 schema -struct +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> -- !query 20 output -spark.sql.decimalOperations.allowPrecisionLoss false +138698367904130467.654320988515622621 -- !query 21 -select id, a+b, a-b, a*b, a/b from decimals_test order by id +select 12345678912345.123456789123 / 0.000000012345678 -- !query 21 schema -struct +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)> -- !query 21 output -1 1099 -899 NULL 0.1001001001001001 -2 24690.246 0 NULL 1 -3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 -4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 +1000000073899961059796.725866332 -- !query 22 -select id, a*10, b/10 from decimals_test order by id +set spark.sql.decimalOperations.allowPrecisionLoss=false -- !query 22 schema -struct +struct -- !query 22 output -1 1000 99.9 -2 123451.23 1234.5123 -3 1.234567891011 123.41 -4 1234567891234567890 0.1123456789123456789 +spark.sql.decimalOperations.allowPrecisionLoss false -- !query 23 -select 10.3 * 3.0 +select id, a+b, a-b, a*b, a/b from decimals_test order by id -- !query 23 schema -struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> +struct -- !query 23 output -30.9 +1 1099 -899 NULL 0.1001001001001001 +2 24690.246 0 NULL 1 +3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 +4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 -- !query 24 -select 10.3000 * 3.0 +select id, a*10, b/10 from decimals_test order by id -- !query 24 schema -struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> +struct -- !query 24 output -30.9 +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.1123456789123456789 -- !query 25 -select 10.30000 * 30.0 +select 10.3 * 3.0 -- !query 25 schema -struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> -- !query 25 output -309 +30.9 -- !query 26 -select 10.300000000000000000 * 3.000000000000000000 +select 10.3000 * 3.0 -- !query 26 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)> +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> -- !query 26 output 30.9 -- !query 27 -select 10.300000000000000000 * 3.0000000000000000000 +select 10.30000 * 30.0 -- !query 27 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)> +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> -- !query 27 output -NULL +309 -- !query 28 -select (5e36 + 0.1) + 5e36 +select 10.300000000000000000 * 3.000000000000000000 -- !query 28 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)> -- !query 28 output -NULL +30.9 -- !query 29 -select (-4e36 - 0.1) - 7e36 +select 10.300000000000000000 * 3.0000000000000000000 -- !query 29 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)> -- !query 29 output NULL -- !query 30 -select 12345678901234567890.0 * 12345678901234567890.0 +select 2.35E10 * 1.0 -- !query 30 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(6,-7)> -- !query 30 output -NULL +23500000000 -- !query 31 -select 1e35 / 0.1 +select (5e36 + 0.1) + 5e36 -- !query 31 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 31 output NULL -- !query 32 -select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +select (-4e36 - 0.1) - 7e36 -- !query 32 schema -struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)> +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 32 output NULL -- !query 33 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 12345678901234567890.0 * 12345678901234567890.0 -- !query 33 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> -- !query 33 output NULL -- !query 34 -select 12345678912345.123456789123 / 0.000000012345678 +select 1e35 / 0.1 -- !query 34 schema -struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)> +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> -- !query 34 output NULL -- !query 35 -drop table decimals_test +select 1.2345678901234567890E30 * 1.2345678901234567890E25 -- !query 35 schema -struct<> +struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(25,-6)) * CAST(1.2345678901234567890E+25 AS DECIMAL(25,-6))):decimal(38,-17)> -- !query 35 output +NULL + + +-- !query 36 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +-- !query 36 schema +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)> +-- !query 36 output +NULL + + +-- !query 37 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 37 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +-- !query 37 output +NULL + + +-- !query 38 +select 12345678912345.123456789123 / 0.000000012345678 +-- !query 38 schema +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)> +-- !query 38 output +NULL + + +-- !query 39 +drop table decimals_test +-- !query 39 schema +struct<> +-- !query 39 output diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out new file mode 100644 index 0000000000000..d352b7284ae87 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out @@ -0,0 +1,143 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 8 + + +-- !query 0 +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), map(false, true), + map(1Y, 2Y), map(3Y, 4Y), + map(1S, 2S), map(3S, 4S), + map(4, 6), map(7, 8), + map(6L, 7L), map(8L, 9L), + map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809), + map(1.0D, 2.0D), map(3.0D, 4.0D), + map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)), + map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'), + map('a', 'b'), map('c', 'd'), + map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), + map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), + map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), + map('a', 1), map('c', 2), + map(1, 'a'), map(2, 'c') +) AS various_maps ( + boolean_map1, boolean_map2, + tinyint_map1, tinyint_map2, + smallint_map1, smallint_map2, + int_map1, int_map2, + bigint_map1, bigint_map2, + decimal_map1, decimal_map2, + double_map1, double_map2, + float_map1, float_map2, + date_map1, date_map2, + timestamp_map1, + timestamp_map2, + string_map1, string_map2, + array_map1, array_map2, + struct_map1, struct_map2, + map_map1, map_map2, + string_int_map1, string_int_map2, + int_string_map1, int_string_map2 +) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT + map_concat(boolean_map1, boolean_map2) boolean_map, + map_concat(tinyint_map1, tinyint_map2) tinyint_map, + map_concat(smallint_map1, smallint_map2) smallint_map, + map_concat(int_map1, int_map2) int_map, + map_concat(bigint_map1, bigint_map2) bigint_map, + map_concat(decimal_map1, decimal_map2) decimal_map, + map_concat(float_map1, float_map2) float_map, + map_concat(double_map1, double_map2) double_map, + map_concat(date_map1, date_map2) date_map, + map_concat(timestamp_map1, timestamp_map2) timestamp_map, + map_concat(string_map1, string_map2) string_map, + map_concat(array_map1, array_map2) array_map, + map_concat(struct_map1, struct_map2) struct_map, + map_concat(map_map1, map_map2) map_map, + map_concat(string_int_map1, string_int_map2) string_int_map, + map_concat(int_string_map1, int_string_map2) int_string_map +FROM various_maps +-- !query 1 schema +struct,tinyint_map:map,smallint_map:map,int_map:map,bigint_map:map,decimal_map:map,float_map:map,double_map:map,date_map:map,timestamp_map:map,string_map:map,array_map:map,array>,struct_map:map,struct>,map_map:map,map>,string_int_map:map,int_string_map:map> +-- !query 1 output +{false:true,true:false} {1:2,3:4} {1:2,3:4} {4:6,7:8} {6:7,8:9} {9223372036854775808:9223372036854775809,9223372036854775809:9223372036854775808} {1.0:2.0,3.0:4.0} {1.0:2.0,3.0:4.0} {2016-03-12:2016-03-11,2016-03-14:2016-03-13} {2016-11-11 20:54:00.0:2016-11-09 20:54:00.0,2016-11-15 20:54:00.0:2016-11-12 20:54:00.0} {"a":"b","c":"d"} {["a","b"]:["c","d"],["e"]:["f"]} {{"col1":"a","col2":1}:{"col1":"b","col2":2},{"col1":"c","col2":3}:{"col1":"d","col2":4}} {{"a":1}:{"b":2},{"c":3}:{"d":4}} {"a":1,"c":2} {1:"a",2:"c"} + + +-- !query 2 +SELECT + map_concat(tinyint_map1, smallint_map2) ts_map, + map_concat(smallint_map1, int_map2) si_map, + map_concat(int_map1, bigint_map2) ib_map, + map_concat(decimal_map1, float_map2) df_map, + map_concat(string_map1, date_map2) std_map, + map_concat(timestamp_map1, string_map2) tst_map, + map_concat(string_map1, int_map2) sti_map, + map_concat(int_string_map1, tinyint_map2) istt_map +FROM various_maps +-- !query 2 schema +struct,si_map:map,ib_map:map,df_map:map,std_map:map,tst_map:map,sti_map:map,istt_map:map> +-- !query 2 output +{1:2,3:4} {1:2,7:8} {4:6,8:9} {3.0:4.0,9.223372036854776E18:9.223372036854776E18} {"2016-03-12":"2016-03-11","a":"b"} {"2016-11-15 20:54:00":"2016-11-12 20:54:00","c":"d"} {"7":"8","a":"b"} {1:"a",3:"4"} + + +-- !query 3 +SELECT + map_concat(tinyint_map1, map_map2) tm_map +FROM various_maps +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`tinyint_map1`, various_maps.`map_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,map>]; line 2 pos 4 + + +-- !query 4 +SELECT + map_concat(boolean_map1, int_map2) bi_map +FROM various_maps +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`boolean_map1`, various_maps.`int_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map]; line 2 pos 4 + + +-- !query 5 +SELECT + map_concat(int_map1, struct_map2) is_map +FROM various_maps +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`int_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,struct>]; line 2 pos 4 + + +-- !query 6 +SELECT + map_concat(map_map1, array_map2) ma_map +FROM various_maps +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,array>]; line 2 pos 4 + + +-- !query 7 +SELECT + map_concat(map_map1, struct_map2) ms_map +FROM various_maps +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,struct>]; line 2 pos 4 diff --git a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out new file mode 100644 index 0000000000000..d7d009a64bf84 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out @@ -0,0 +1,93 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 3 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (101, 1, 1, 1), + (201, 2, 1, 1), + (301, 3, 1, 1), + (401, 4, 1, 11), + (501, 5, 1, null), + (601, 6, null, 1), + (701, 6, null, null), + (102, 1, 2, 2), + (202, 2, 1, 2), + (302, 3, 2, 1), + (402, 4, 2, 12), + (502, 5, 2, null), + (602, 6, null, 2), + (702, 6, null, null), + (103, 1, 3, 3), + (203, 2, 1, 3), + (303, 3, 3, 1), + (403, 4, 3, 13), + (503, 5, 3, null), + (603, 6, null, 3), + (703, 6, null, null), + (104, 1, 4, 4), + (204, 2, 1, 4), + (304, 3, 4, 1), + (404, 4, 4, 14), + (504, 5, 4, null), + (604, 6, null, 4), + (704, 6, null, null), + (800, 7, 1, 1) +as t1(id, px, y, x) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), + regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), + regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) +from t1 group by px order by px +-- !query 1 schema +struct +-- !query 1 output +1 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 0.0 1.0 5.0 5.0 5.0 2.5 2.5 4 +2 1.25 0.0 NULL 0.0 0.0 4 0.0 1.0 1.0 5.0 0.0 0.0 2.5 1.0 4 +3 0.0 1.25 NULL 0.0 0.0 4 NULL NULL NULL 0.0 5.0 0.0 1.0 2.5 4 +4 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 -10.0 1.0 5.0 5.0 5.0 12.5 2.5 4 +5 NULL 1.25 NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 +6 1.25 NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 +7 0.0 0.0 NaN NaN 0.0 1 NULL NULL NULL 0.0 0.0 0.0 1.0 1.0 1 + + +-- !query 2 +select id, regr_count(y,x) over (partition by px) from t1 order by id +-- !query 2 schema +struct +-- !query 2 output +101 4 +102 4 +103 4 +104 4 +201 4 +202 4 +203 4 +204 4 +301 4 +302 4 +303 4 +304 4 +401 4 +402 4 +403 4 +404 4 +501 0 +502 0 +503 0 +504 0 +601 0 +602 0 +603 0 +604 0 +701 0 +702 0 +703 0 +704 0 +800 1 diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out index d123b7fdbe0cf..b023df825d814 100644 --- a/sql/core/src/test/resources/sql-tests/results/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 14 +-- Number of queries: 16 -- !query 0 @@ -105,23 +105,29 @@ struct -- !query 9 -DROP VIEW IF EXISTS t1 +SELECT map(1, 2), 'str' +UNION ALL +SELECT map(1, 2, 3, NULL), 1 -- !query 9 schema -struct<> +struct,str:string> -- !query 9 output - +{1:2,3:null} 1 +{1:2} str -- !query 10 -DROP VIEW IF EXISTS t2 +SELECT array(1, 2), 'str' +UNION ALL +SELECT array(1, 2, 3, NULL), 1 -- !query 10 schema -struct<> +struct,str:string> -- !query 10 output - +[1,2,3,null] 1 +[1,2] str -- !query 11 -DROP VIEW IF EXISTS p1 +DROP VIEW IF EXISTS t1 -- !query 11 schema struct<> -- !query 11 output @@ -129,7 +135,7 @@ struct<> -- !query 12 -DROP VIEW IF EXISTS p2 +DROP VIEW IF EXISTS t2 -- !query 12 schema struct<> -- !query 12 output @@ -137,8 +143,24 @@ struct<> -- !query 13 -DROP VIEW IF EXISTS p3 +DROP VIEW IF EXISTS p1 -- !query 13 schema struct<> -- !query 13 output + + +-- !query 14 +DROP VIEW IF EXISTS p2 +-- !query 14 schema +struct<> +-- !query 14 output + + + +-- !query 15 +DROP VIEW IF EXISTS p3 +-- !query 15 schema +struct<> +-- !query 15 output + diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata new file mode 100644 index 0000000000000..d6be7fbffa9b7 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata @@ -0,0 +1 @@ +{"id":"549eeb1a-d762-420c-bb44-3fd6d73a5268"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 new file mode 100644 index 0000000000000..43db49d052894 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 @@ -0,0 +1,4 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1531172902041,"conf":{"spark.sql.shuffle.partitions":"10","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 new file mode 100644 index 0000000000000..8cc898e81017f --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 @@ -0,0 +1,4 @@ +v1 +{"batchWatermarkMs":10000,"batchTimestampMs":1531172902217,"conf":{"spark.sql.shuffle.partitions":"10","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/test-data/comments-whitespaces.csv b/sql/core/src/test/resources/test-data/comments-whitespaces.csv new file mode 100644 index 0000000000000..2737978f83a5e --- /dev/null +++ b/sql/core/src/test/resources/test-data/comments-whitespaces.csv @@ -0,0 +1,8 @@ +# The file contains comments, whitespaces and empty lines +colA +# empty line + +# the line with a few whitespaces + +# int value with leading and trailing whitespaces + "a" diff --git a/sql/core/src/test/resources/test-data/parquet-1217.parquet b/sql/core/src/test/resources/test-data/parquet-1217.parquet new file mode 100644 index 0000000000000..eb2dc4f799070 Binary files /dev/null and b/sql/core/src/test/resources/test-data/parquet-1217.parquet differ diff --git a/sql/core/src/test/resources/test-data/utf16LE.json b/sql/core/src/test/resources/test-data/utf16LE.json new file mode 100644 index 0000000000000..ce4117fd299df Binary files /dev/null and b/sql/core/src/test/resources/test-data/utf16LE.json differ diff --git a/sql/core/src/test/resources/test-data/utf16WithBOM.json b/sql/core/src/test/resources/test-data/utf16WithBOM.json new file mode 100644 index 0000000000000..cf4d29328b860 Binary files /dev/null and b/sql/core/src/test/resources/test-data/utf16WithBOM.json differ diff --git a/sql/core/src/test/resources/test-data/utf32BEWithBOM.json b/sql/core/src/test/resources/test-data/utf32BEWithBOM.json new file mode 100644 index 0000000000000..6c7733c577872 Binary files /dev/null and b/sql/core/src/test/resources/test-data/utf32BEWithBOM.json differ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index 137c5bea2abb9..d635912cf7205 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -279,4 +280,16 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(query, expected) } } + + test("SPARK-24013: unneeded compress can cause performance issues with sorted input") { + val buffer = new PercentileDigest(1.0D / ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY) + var compressCounts = 0 + (1 to 10000000).foreach { i => + buffer.add(i) + if (buffer.isCompressed) compressCounts += 1 + } + assert(compressCounts > 0) + buffer.quantileSummaries + assert(buffer.isCompressed) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 669e5f2bf4e65..60c73df88896b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,12 +22,12 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.spark.CleanerListener +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -52,7 +52,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val plan = spark.table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => - relation.cachedColumnBuffers.id + relation.cacheBuilder.cachedColumnBuffers.id case _ => fail(s"Table $tableName is not cached\n" + plan) }.head @@ -78,29 +78,10 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = { plan.collect { case InMemoryTableScanExec(_, _, relation) => - getNumInMemoryTablesRecursively(relation.child) + 1 + getNumInMemoryTablesRecursively(relation.cachedPlan) + 1 }.sum } - test("withColumn doesn't invalidate cached dataframe") { - var evalCount = 0 - val myUDF = udf((x: String) => { evalCount += 1; "result" }) - val df = Seq(("test", 1)).toDF("s", "i").select(myUDF($"s")) - df.cache() - - df.collect() - assert(evalCount === 1) - - df.collect() - assert(evalCount === 1) - - val df2 = df.withColumn("newColumn", lit(1)) - df2.collect() - - // We should not reevaluate the cached dataframe - assert(evalCount === 1) - } - test("cache temp table") { withTempView("tempTable") { testData.select('key).createOrReplaceTempView("tempTable") @@ -200,7 +181,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.catalog.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { spark.table("testData").queryExecution.withCachedData.collect { - case r @ InMemoryRelation(_, _, _, _, _: InMemoryTableScanExec, _) => r + case r: InMemoryRelation if r.cachedPlan.isInstanceOf[InMemoryTableScanExec] => r }.size } @@ -367,12 +348,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val toBeCleanedAccIds = new HashSet[Long] val accId1 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId1 val accId2 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId2 @@ -794,4 +775,94 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } } } + + private def checkIfNoJobTriggered[T](f: => T): T = { + var numJobTrigered = 0 + val jobListener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + numJobTrigered += 1 + } + } + sparkContext.addSparkListener(jobListener) + try { + val result = f + sparkContext.listenerBus.waitUntilEmpty(10000L) + assert(numJobTrigered === 0) + result + } finally { + sparkContext.removeSparkListener(jobListener) + } + } + + test("SPARK-23880 table cache should be lazy and don't trigger any jobs") { + val cachedData = checkIfNoJobTriggered { + spark.range(1002).filter('id > 1000).orderBy('id.desc).cache() + } + assert(cachedData.collect === Seq(1001)) + } + + test("SPARK-24596 Non-cascading Cache Invalidation - uncache temporary view") { + withTempView("t1", "t2") { + sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1") + sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("UNCACHE TABLE t1") + assert(!spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - drop temporary view") { + withTempView("t1", "t2") { + sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1") + sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("DROP VIEW t1") + assert(spark.catalog.isCached("t2")) + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - drop persistent view") { + withTable("t") { + spark.range(1, 10).toDF("key").withColumn("value", 'key * 2) + .write.format("json").saveAsTable("t") + withView("t1") { + withTempView("t2") { + sql("CREATE VIEW t1 AS SELECT * FROM t WHERE key > 1") + + sql("CACHE TABLE t1") + sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("DROP VIEW t1") + assert(!spark.catalog.isCached("t2")) + } + } + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - uncache table") { + withTable("t") { + spark.range(1, 10).toDF("key").withColumn("value", 'key * 2) + .write.format("json").saveAsTable("t") + withTempView("t1", "t2") { + sql("CACHE TABLE t") + sql("CACHE TABLE t1 AS SELECT * FROM t WHERE key > 1") + sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t")) + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("UNCACHE TABLE t") + assert(!spark.catalog.isCached("t")) + assert(!spark.catalog.isCached("t1")) + assert(!spark.catalog.isCached("t2")) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 7c45be21961d3..2182bd7eadd63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import java.util.Locale + +import scala.collection.JavaConverters._ + import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.scalatest.Matchers._ @@ -390,11 +394,67 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { checkAnswer(df.filter($"b".isin("z", "y")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + // Auto casting should work with mixture of different types in collections + checkAnswer(df.filter($"a".isin(1.toShort, "2")), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin("3", 2.toLong)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin(3, "1")), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") - intercept[AnalysisException] { + val e = intercept[AnalysisException] { df2.filter($"a".isin($"b")) } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } + + test("isInCollection: Scala Collection") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + val e = intercept[AnalysisException] { + df2.filter($"a".isInCollection(Seq($"b"))) + } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } + + test("isInCollection: Java Collection") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + val e = intercept[AnalysisException] { + df2.filter($"a".isInCollection(Seq($"b").asJava)) + } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } } test("&&") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala index 949505e449fd7..276496be3d62c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -39,7 +39,9 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { def computeChiSquareTest(): Double = { val n = 10000 // Trigger a sort - val data = spark.range(0, n, 1, 1).sort('id.desc) + // Range has range partitioning in its output now. To have a range shuffle, we + // need to run a repartition first. + val data = spark.range(0, n, 1, 1).repartition(10).sort('id.desc) .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() // Compute histogram for the number of records per partition post sort diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e7776e36702ad..f495a949ebc5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql import scala.util.Random -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.scalatest.Matchers.the + import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -36,6 +36,8 @@ case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Doub class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { import testImplicits._ + val absTol = 1e-8 + test("groupBy") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), @@ -416,7 +418,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("moments") { - val absTol = 1e-8 val sparkVariance = testData2.agg(variance('a)) checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol) @@ -686,4 +687,34 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-21896: Window functions inside aggregate functions") { + def checkWindowError(df: => DataFrame): Unit = { + val thrownException = the [AnalysisException] thrownBy { + df.queryExecution.analyzed + } + assert(thrownException.message.contains("not allowed to use a window function")) + } + + checkWindowError(testData2.select(min(avg('b).over(Window.partitionBy('a))))) + checkWindowError(testData2.agg(sum('b), max(rank().over(Window.orderBy('a))))) + checkWindowError(testData2.groupBy('a).agg(sum('b), max(rank().over(Window.orderBy('b))))) + checkWindowError(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('a))))) + checkWindowError( + testData2.groupBy('a).agg(sum('b).as("s"), max(count("*").over())).where('s === 3)) + checkAnswer( + testData2.groupBy('a).agg(max('b), sum('b).as("s"), count("*").over()).where('s === 3), + Row(1, 2, 3, 3) :: Row(2, 2, 3, 3) :: Row(3, 2, 3, 3) :: Nil) + + checkWindowError(sql("SELECT MIN(AVG(b) OVER(PARTITION BY a)) FROM testData2")) + checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY a)) FROM testData2")) + checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a")) + checkWindowError(sql("SELECT MAX(SUM(SUM(b)) OVER(ORDER BY a)) FROM testData2 GROUP BY a")) + checkWindowError( + sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3")) + checkAnswer( + sql("SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3"), + Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 25e5cd60dd236..d4615714cff03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -18,12 +18,15 @@ package org.apache.spark.sql import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} +import java.util.TimeZone import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -62,6 +65,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(row.getMap[Int, String](0) === Map(2 -> "a")) } + test("map with arrays") { + val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v") + val expectedType = MapType(IntegerType, StringType, valueContainsNull = true) + val row = df1.select(map_from_arrays($"k", $"v")).first() + assert(row.schema(0).dataType === expectedType) + assert(row.getMap[Int, String](0) === Map(1 -> "a", 2 -> "b")) + checkAnswer(df1.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b")))) + + val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v") + checkAnswer(df2.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b")))) + + val df3 = Seq((null, null)).toDF("k", "v") + checkAnswer(df3.select(map_from_arrays($"k", $"v")), Seq(Row(null))) + + val df4 = Seq((1, "a")).toDF("k", "v") + intercept[AnalysisException] { + df4.select(map_from_arrays($"k", $"v")) + } + + val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v") + intercept[RuntimeException] { + df5.select(map_from_arrays($"k", $"v")).collect + } + + val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v") + intercept[RuntimeException] { + df6.select(map_from_arrays($"k", $"v")).collect + } + } + test("struct with column name") { val df = Seq((1, "str")).toDF("a", "b") val row = df.select(struct("a", "b")).first() @@ -276,7 +309,114 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - test("sort_array function") { + test("mask functions") { + val df = Seq("TestString-123", "", null).toDF("a") + checkAnswer(df.select(mask($"a")), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) + checkAnswer(df.select(mask_first_n($"a", 4)), Seq(Row("XxxxString-123"), Row(""), Row(null))) + checkAnswer(df.select(mask_last_n($"a", 4)), Seq(Row("TestString-nnn"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_first_n($"a", 4)), + Seq(Row("TestXxxxxx-nnn"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_last_n($"a", 4)), + Seq(Row("XxxxXxxxxx-123"), Row(""), Row(null))) + checkAnswer(df.select(mask_hash($"a")), + Seq(Row("dd78d68ad1b23bde126812482dd70ac6"), + Row("d41d8cd98f00b204e9800998ecf8427e"), + Row(null))) + + checkAnswer(df.select(mask($"a", "U", "l", "#")), + Seq(Row("UlllUlllll-###"), Row(""), Row(null))) + checkAnswer(df.select(mask_first_n($"a", 4, "U", "l", "#")), + Seq(Row("UlllString-123"), Row(""), Row(null))) + checkAnswer(df.select(mask_last_n($"a", 4, "U", "l", "#")), + Seq(Row("TestString-###"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_first_n($"a", 4, "U", "l", "#")), + Seq(Row("TestUlllll-###"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_last_n($"a", 4, "U", "l", "#")), + Seq(Row("UlllUlllll-123"), Row(""), Row(null))) + + checkAnswer( + df.selectExpr("mask(a)", "mask(a, 'U')", "mask(a, 'U', 'l')", "mask(a, 'U', 'l', '#')"), + Seq(Row("XxxxXxxxxx-nnn", "UxxxUxxxxx-nnn", "UlllUlllll-nnn", "UlllUlllll-###"), + Row("", "", "", ""), + Row(null, null, null, null))) + checkAnswer(sql("select mask(null)"), Row(null)) + checkAnswer(sql("select mask('AAaa11', null, null, null)"), Row("XXxxnn")) + intercept[AnalysisException] { + checkAnswer(df.selectExpr("mask(a, a)"), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) + } + + checkAnswer( + df.selectExpr( + "mask_first_n(a)", + "mask_first_n(a, 6)", + "mask_first_n(a, 6, 'U')", + "mask_first_n(a, 6, 'U', 'l')", + "mask_first_n(a, 6, 'U', 'l', '#')"), + Seq(Row("XxxxString-123", "XxxxXxring-123", "UxxxUxring-123", "UlllUlring-123", + "UlllUlring-123"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_first_n(null)"), Row(null)) + checkAnswer(sql("select mask_first_n('A1aA1a', null, null, null, null)"), Row("XnxX1a")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_first_n('A1aA1a', id)"), Row("XnxX1a")) + } + + checkAnswer( + df.selectExpr( + "mask_last_n(a)", + "mask_last_n(a, 6)", + "mask_last_n(a, 6, 'U')", + "mask_last_n(a, 6, 'U', 'l')", + "mask_last_n(a, 6, 'U', 'l', '#')"), + Seq(Row("TestString-nnn", "TestStrixx-nnn", "TestStrixx-nnn", "TestStrill-nnn", + "TestStrill-###"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_last_n(null)"), Row(null)) + checkAnswer(sql("select mask_last_n('A1aA1a', null, null, null, null)"), Row("A1xXnx")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_last_n('A1aA1a', id)"), Row("A1xXnx")) + } + + checkAnswer( + df.selectExpr( + "mask_show_first_n(a)", + "mask_show_first_n(a, 6)", + "mask_show_first_n(a, 6, 'U')", + "mask_show_first_n(a, 6, 'U', 'l')", + "mask_show_first_n(a, 6, 'U', 'l', '#')"), + Seq(Row("TestXxxxxx-nnn", "TestStxxxx-nnn", "TestStxxxx-nnn", "TestStllll-nnn", + "TestStllll-###"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_show_first_n(null)"), Row(null)) + checkAnswer(sql("select mask_show_first_n('A1aA1a', null, null, null, null)"), Row("A1aAnx")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_show_first_n('A1aA1a', id)"), Row("A1aAnx")) + } + + checkAnswer( + df.selectExpr( + "mask_show_last_n(a)", + "mask_show_last_n(a, 6)", + "mask_show_last_n(a, 6, 'U')", + "mask_show_last_n(a, 6, 'U', 'l')", + "mask_show_last_n(a, 6, 'U', 'l', '#')"), + Seq(Row("XxxxXxxxxx-123", "XxxxXxxxng-123", "UxxxUxxxng-123", "UlllUlllng-123", + "UlllUlllng-123"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_show_last_n(null)"), Row(null)) + checkAnswer(sql("select mask_show_last_n('A1aA1a', null, null, null, null)"), Row("XnaA1a")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_show_last_n('A1aA1a', id)"), Row("XnaA1a")) + } + + checkAnswer(sql("select mask_hash(null)"), Row(null)) + } + + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), (Array.empty[Int], Array.empty[String]), @@ -286,28 +426,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(sort_array($"a"), sort_array($"b")), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.select(sort_array($"a", false), sort_array($"b", false)), Seq( Row(Seq(3, 2, 1), Seq("c", "b", "a")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.selectExpr("sort_array(a)", "sort_array(b)"), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.selectExpr("sort_array(a, true)", "sort_array(b, false)"), Seq( Row(Seq(1, 2, 3), Seq("c", "b", "a")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) @@ -324,40 +464,137 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(intercept[AnalysisException] { df3.selectExpr("sort_array(a)").collect() }.getMessage().contains("only supports array input")) + + checkAnswer( + df.select(array_sort($"a"), array_sort($"b")), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("array_sort(a)", "array_sort(b)"), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + + checkAnswer( + df2.selectExpr("array_sort(a)"), + Seq(Row(Seq[Seq[Int]](Seq(1), Seq(2), Seq(2, 4), null))) + ) + + assert(intercept[AnalysisException] { + df3.selectExpr("array_sort(a)").collect() + }.getMessage().contains("only supports array input")) } - test("array size function") { + def testSizeOfArray(sizeOfNull: Any): Unit = { val df = Seq( (Seq[Int](1, 2), "x"), (Seq[Int](), "y"), (Seq[Int](1, 2, 3), "z"), (null, "empty") ).toDF("a", "b") - checkAnswer( - df.select(size($"a")), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) - checkAnswer( - df.selectExpr("size(a)"), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) + + checkAnswer(df.select(size($"a")), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + checkAnswer(df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + checkAnswer(df.selectExpr("cardinality(a)"), Seq(Row(2L), Row(0L), Row(3L), Row(sizeOfNull))) } - test("map size function") { + test("array size function - legacy") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSizeOfArray(sizeOfNull = -1) + } + } + + test("array size function") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSizeOfArray(sizeOfNull = null) + } + } + + test("dataframe arrays_zip function") { + val df1 = Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6))).toDF("val1", "val2") + val df2 = Seq((Seq("a", "b"), Seq(true, false), Seq(10, 11))).toDF("val1", "val2", "val3") + val df3 = Seq((Seq("a", "b"), Seq(4, 5, 6))).toDF("val1", "val2") + val df4 = Seq((Seq("a", "b", null), Seq(4L))).toDF("val1", "val2") + val df5 = Seq((Seq(-1), Seq(null), Seq(), Seq(null, null))).toDF("val1", "val2", "val3", "val4") + val df6 = Seq((Seq(192.toByte, 256.toByte), Seq(1.1), Seq(), Seq(null, null))) + .toDF("v1", "v2", "v3", "v4") + val df7 = Seq((Seq(Seq(1, 2, 3), Seq(4, 5)), Seq(1.1, 2.2))).toDF("v1", "v2") + val df8 = Seq((Seq(Array[Byte](1.toByte, 5.toByte)), Seq(null))).toDF("v1", "v2") + + val expectedValue1 = Row(Seq(Row(9001, 4), Row(9002, 5), Row(9003, 6))) + checkAnswer(df1.select(arrays_zip($"val1", $"val2")), expectedValue1) + checkAnswer(df1.selectExpr("arrays_zip(val1, val2)"), expectedValue1) + + val expectedValue2 = Row(Seq(Row("a", true, 10), Row("b", false, 11))) + checkAnswer(df2.select(arrays_zip($"val1", $"val2", $"val3")), expectedValue2) + checkAnswer(df2.selectExpr("arrays_zip(val1, val2, val3)"), expectedValue2) + + val expectedValue3 = Row(Seq(Row("a", 4), Row("b", 5), Row(null, 6))) + checkAnswer(df3.select(arrays_zip($"val1", $"val2")), expectedValue3) + checkAnswer(df3.selectExpr("arrays_zip(val1, val2)"), expectedValue3) + + val expectedValue4 = Row(Seq(Row("a", 4L), Row("b", null), Row(null, null))) + checkAnswer(df4.select(arrays_zip($"val1", $"val2")), expectedValue4) + checkAnswer(df4.selectExpr("arrays_zip(val1, val2)"), expectedValue4) + + val expectedValue5 = Row(Seq(Row(-1, null, null, null), Row(null, null, null, null))) + checkAnswer(df5.select(arrays_zip($"val1", $"val2", $"val3", $"val4")), expectedValue5) + checkAnswer(df5.selectExpr("arrays_zip(val1, val2, val3, val4)"), expectedValue5) + + val expectedValue6 = Row(Seq( + Row(192.toByte, 1.1, null, null), Row(256.toByte, null, null, null))) + checkAnswer(df6.select(arrays_zip($"v1", $"v2", $"v3", $"v4")), expectedValue6) + checkAnswer(df6.selectExpr("arrays_zip(v1, v2, v3, v4)"), expectedValue6) + + val expectedValue7 = Row(Seq( + Row(Seq(1, 2, 3), 1.1), Row(Seq(4, 5), 2.2))) + checkAnswer(df7.select(arrays_zip($"v1", $"v2")), expectedValue7) + checkAnswer(df7.selectExpr("arrays_zip(v1, v2)"), expectedValue7) + + val expectedValue8 = Row(Seq( + Row(Array[Byte](1.toByte, 5.toByte), null))) + checkAnswer(df8.select(arrays_zip($"v1", $"v2")), expectedValue8) + checkAnswer(df8.selectExpr("arrays_zip(v1, v2)"), expectedValue8) + } + + test("SPARK-24633: arrays_zip splits input processing correctly") { + Seq("true", "false").foreach { wholestageCodegenEnabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholestageCodegenEnabled) { + val df = spark.range(1) + val exprs = (0 to 5).map(x => array($"id" + lit(x))) + checkAnswer(df.select(arrays_zip(exprs: _*)), + Row(Seq(Row(0, 1, 2, 3, 4, 5)))) + } + } + } + + def testSizeOfMap(sizeOfNull: Any): Unit = { val df = Seq( (Map[Int, Int](1 -> 1, 2 -> 2), "x"), (Map[Int, Int](), "y"), (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z"), (null, "empty") ).toDF("a", "b") - checkAnswer( - df.select(size($"a")), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) - checkAnswer( - df.selectExpr("size(a)"), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) + + checkAnswer(df.select(size($"a")), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + checkAnswer(df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + } + + test("map size function - legacy") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSizeOfMap(sizeOfNull = -1: Int) + } + } + + test("map size function") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSizeOfMap(sizeOfNull = null) + } } test("map_keys/map_values function") { @@ -376,11 +613,183 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("map_entries") { + val dummyFilter = (c: Column) => c.isNotNull || c.isNull + + // Primitive-type elements + val idf = Seq( + Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300), + Map[Int, Int](), + null + ).toDF("m") + val iExpected = Seq( + Row(Seq(Row(1, 100), Row(2, 200), Row(3, 300))), + Row(Seq.empty), + Row(null) + ) + + checkAnswer(idf.select(map_entries('m)), iExpected) + checkAnswer(idf.selectExpr("map_entries(m)"), iExpected) + checkAnswer(idf.filter(dummyFilter('m)).select(map_entries('m)), iExpected) + checkAnswer( + spark.range(1).selectExpr("map_entries(map(1, null, 2, null))"), + Seq(Row(Seq(Row(1, null), Row(2, null))))) + checkAnswer( + spark.range(1).filter(dummyFilter('id)).selectExpr("map_entries(map(1, null, 2, null))"), + Seq(Row(Seq(Row(1, null), Row(2, null))))) + + // Non-primitive-type elements + val sdf = Seq( + Map[String, String]("a" -> "f", "b" -> "o", "c" -> "o"), + Map[String, String]("a" -> null, "b" -> null), + Map[String, String](), + null + ).toDF("m") + val sExpected = Seq( + Row(Seq(Row("a", "f"), Row("b", "o"), Row("c", "o"))), + Row(Seq(Row("a", null), Row("b", null))), + Row(Seq.empty), + Row(null) + ) + + checkAnswer(sdf.select(map_entries('m)), sExpected) + checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected) + checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) + } + + test("map_concat function") { + val df1 = Seq( + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 4 -> 400)), + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 1 -> 400)), + (null, Map[Int, Int](3 -> 300, 4 -> 400)) + ).toDF("map1", "map2") + + val expected1a = Seq( + Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)), + Row(Map(1 -> 400, 2 -> 200, 3 -> 300)), + Row(null) + ) + + checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a) + checkAnswer(df1.select(map_concat('map1, 'map2)), expected1a) + + val expected1b = Seq( + Row(Map(1 -> 100, 2 -> 200)), + Row(Map(1 -> 100, 2 -> 200)), + Row(null) + ) + + checkAnswer(df1.selectExpr("map_concat(map1)"), expected1b) + checkAnswer(df1.select(map_concat('map1)), expected1b) + + val df2 = Seq( + ( + Map[Array[Int], Int](Array(1) -> 100, Array(2) -> 200), + Map[String, Int]("3" -> 300, "4" -> 400) + ) + ).toDF("map1", "map2") + + val expected2 = Seq(Row(Map())) + + checkAnswer(df2.selectExpr("map_concat()"), expected2) + checkAnswer(df2.select(map_concat()), expected2) + + val df3 = { + val schema = StructType( + StructField("map1", MapType(StringType, IntegerType, true), false) :: + StructField("map2", MapType(StringType, IntegerType, false), false) :: Nil + ) + val data = Seq( + Row(Map[String, Any]("a" -> 1, "b" -> null), Map[String, Any]("c" -> 3, "d" -> 4)), + Row(Map[String, Any]("a" -> 1, "b" -> 2), Map[String, Any]("c" -> 3, "d" -> 4)) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + val expected3 = Seq( + Row(Map[String, Any]("a" -> 1, "b" -> null, "c" -> 3, "d" -> 4)), + Row(Map[String, Any]("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4)) + ) + + checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected3) + checkAnswer(df3.select(map_concat('map1, 'map2)), expected3) + + val expectedMessage1 = "input to function map_concat should all be the same type" + + assert(intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, map2)").collect() + }.getMessage().contains(expectedMessage1)) + + assert(intercept[AnalysisException] { + df2.select(map_concat('map1, 'map2)).collect() + }.getMessage().contains(expectedMessage1)) + + val expectedMessage2 = "input to function map_concat should all be of type map" + + assert(intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, 12)").collect() + }.getMessage().contains(expectedMessage2)) + + assert(intercept[AnalysisException] { + df2.select(map_concat('map1, lit(12))).collect() + }.getMessage().contains(expectedMessage2)) + } + + test("map_from_entries function") { + def dummyFilter(c: Column): Column = c.isNull || c.isNotNull + val oneRowDF = Seq(3215).toDF("i") + + // Test cases with primitive-type keys and values + val idf = Seq( + Seq((1, 10), (2, 20), (3, 10)), + Seq((1, 10), null, (2, 20)), + Seq.empty, + null + ).toDF("a") + val iExpected = Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 10)), + Row(null), + Row(Map.empty), + Row(null)) + + checkAnswer(idf.select(map_from_entries('a)), iExpected) + checkAnswer(idf.selectExpr("map_from_entries(a)"), iExpected) + checkAnswer(idf.filter(dummyFilter('a)).select(map_from_entries('a)), iExpected) + checkAnswer( + oneRowDF.selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq(Row(Map(1 -> null, 2 -> null))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('i)) + .selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq(Row(Map(1 -> null, 2 -> null))) + ) + + // Test cases with non-primitive-type keys and values + val sdf = Seq( + Seq(("a", "aa"), ("b", "bb"), ("c", "aa")), + Seq(("a", "aa"), null, ("b", "bb")), + Seq(("a", null), ("b", null)), + Seq.empty, + null + ).toDF("a") + val sExpected = Seq( + Row(Map("a" -> "aa", "b" -> "bb", "c" -> "aa")), + Row(null), + Row(Map("a" -> null, "b" -> null)), + Row(Map.empty), + Row(null)) + + checkAnswer(sdf.select(map_from_entries('a)), sExpected) + checkAnswer(sdf.selectExpr("map_from_entries(a)"), sExpected) + checkAnswer(sdf.filter(dummyFilter('a)).select(map_from_entries('a)), sExpected) + } + test("array contains function") { val df = Seq( - (Seq[Int](1, 2), "x"), - (Seq[Int](), "x") - ).toDF("a", "b") + (Seq[Int](1, 2), "x", 1), + (Seq[Int](), "x", 1) + ).toDF("a", "b", "c") // Simple test cases checkAnswer( @@ -391,6 +800,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false)) ) + checkAnswer( + df.select(array_contains(df("a"), df("c"))), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(a, c)"), + Seq(Row(true), Row(false)) + ) // In hive, this errors because null has no type information intercept[AnalysisException] { @@ -413,6 +830,91 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("arrays_overlap function") { + val df = Seq( + (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), Some(10))), + (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), None)), + (Seq[Option[Int]](Some(3), Some(2)), Seq[Option[Int]](Some(1), Some(2))) + ).toDF("a", "b") + + val answer = Seq(Row(false), Row(null), Row(true)) + + checkAnswer(df.select(arrays_overlap(df("a"), df("b"))), answer) + checkAnswer(df.selectExpr("arrays_overlap(a, b)"), answer) + + checkAnswer( + Seq((Seq(1, 2, 3), Seq(2.0, 2.5))).toDF("a", "b").selectExpr("arrays_overlap(a, b)"), + Row(true)) + + intercept[AnalysisException] { + sql("select arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))") + } + + intercept[AnalysisException] { + sql("select arrays_overlap(null, null)") + } + + intercept[AnalysisException] { + sql("select arrays_overlap(map(1, 2), map(3, 4))") + } + } + + test("slice function") { + val df = Seq( + Seq(1, 2, 3), + Seq(4, 5) + ).toDF("x") + + val answer = Seq(Row(Seq(2, 3)), Row(Seq(5))) + + checkAnswer(df.select(slice(df("x"), 2, 2)), answer) + checkAnswer(df.selectExpr("slice(x, 2, 2)"), answer) + + val answerNegative = Seq(Row(Seq(3)), Row(Seq(5))) + checkAnswer(df.select(slice(df("x"), -1, 1)), answerNegative) + checkAnswer(df.selectExpr("slice(x, -1, 1)"), answerNegative) + } + + test("array_join function") { + val df = Seq( + (Seq[String]("a", "b"), ","), + (Seq[String]("a", null, "b"), ","), + (Seq.empty[String], ",") + ).toDF("x", "delimiter") + + checkAnswer( + df.select(array_join(df("x"), ";")), + Seq(Row("a;b"), Row("a;b"), Row("")) + ) + checkAnswer( + df.select(array_join(df("x"), ";", "NULL")), + Seq(Row("a;b"), Row("a;NULL;b"), Row("")) + ) + checkAnswer( + df.selectExpr("array_join(x, delimiter)"), + Seq(Row("a,b"), Row("a,b"), Row(""))) + checkAnswer( + df.selectExpr("array_join(x, delimiter, 'NULL')"), + Seq(Row("a,b"), Row("a,NULL,b"), Row(""))) + + val idf = Seq(Seq(1, 2, 3)).toDF("x") + + checkAnswer( + idf.select(array_join(idf("x"), ", ")), + Seq(Row("1, 2, 3")) + ) + checkAnswer( + idf.selectExpr("array_join(x, ', ')"), + Seq(Row("1, 2, 3")) + ) + intercept[AnalysisException] { + idf.selectExpr("array_join(x, 1)") + } + intercept[AnalysisException] { + idf.selectExpr("array_join(x, ', ', 1)") + } + } + test("array_min function") { val df = Seq( Seq[Option[Int]](Some(1), Some(3), Some(2)), @@ -441,6 +943,59 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("array_max(a)"), answer) } + test("sequence") { + checkAnswer(Seq((-2, 2)).toDF().select(sequence('_1, '_2)), Seq(Row(Array(-2, -1, 0, 1, 2)))) + checkAnswer(Seq((7, 2, -2)).toDF().select(sequence('_1, '_2, '_3)), Seq(Row(Array(7, 5, 3)))) + + checkAnswer( + spark.sql("select sequence(" + + " cast('2018-01-01 00:00:00' as timestamp)" + + ", cast('2018-01-02 00:00:00' as timestamp)" + + ", interval 12 hours)"), + Seq(Row(Array( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))))) + + DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("UTC")) { + checkAnswer( + spark.sql("select sequence(" + + " cast('2018-01-01' as date)" + + ", cast('2018-03-01' as date)" + + ", interval 1 month)"), + Seq(Row(Array( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-02-01"), + Date.valueOf("2018-03-01"))))) + } + + // test type coercion + checkAnswer( + Seq((1.toByte, 3L, 1)).toDF().select(sequence('_1, '_2, '_3)), + Seq(Row(Array(1L, 2L, 3L)))) + + checkAnswer( + spark.sql("select sequence(" + + " cast('2018-01-01' as date)" + + ", cast('2018-01-02 00:00:00' as timestamp)" + + ", interval 12 hours)"), + Seq(Row(Array( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))))) + + // test invalid data types + intercept[AnalysisException] { + Seq((true, false)).toDF().selectExpr("sequence(_1, _2)") + } + intercept[AnalysisException] { + Seq((true, false, 42)).toDF().selectExpr("sequence(_1, _2, _3)") + } + intercept[AnalysisException] { + Seq((1, 2, 0.5)).toDF().selectExpr("sequence(_1, _2, _3)") + } + } + test("reverse function") { val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on @@ -537,9 +1092,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array position function") { val df = Seq( - (Seq[Int](1, 2), "x"), - (Seq[Int](), "x") - ).toDF("a", "b") + (Seq[Int](1, 2), "x", 1), + (Seq[Int](), "x", 1) + ).toDF("a", "b", "c") checkAnswer( df.select(array_position(df("a"), 1)), @@ -549,7 +1104,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_position(a, 1)"), Seq(Row(1L), Row(0L)) ) - + checkAnswer( + df.selectExpr("array_position(a, c)"), + Seq(Row(1L), Row(0L)) + ) + checkAnswer( + df.select(array_position(df("a"), df("c"))), + Seq(Row(1L), Row(0L)) + ) checkAnswer( df.select(array_position(df("a"), null)), Seq(Row(null), Row(null)) @@ -567,14 +1129,19 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_position(array(1, null), array(1, null)[0])"), Seq(Row(1L), Row(1L)) ) + + val e = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)") + } + assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) } test("element_at function") { val df = Seq( - (Seq[String]("1", "2", "3")), - (Seq[String](null, "")), - (Seq[String]()) - ).toDF("a") + (Seq[String]("1", "2", "3"), 1), + (Seq[String](null, ""), -1), + (Seq[String](), 2) + ).toDF("a", "b") intercept[Exception] { checkAnswer( @@ -592,6 +1159,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(element_at(df("a"), 4)), Seq(Row(null), Row(null), Row(null)) ) + checkAnswer( + df.select(element_at(df("a"), df("b"))), + Seq(Row("1"), Row(""), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, b)"), + Seq(Row("1"), Row(""), Row(null)) + ) checkAnswer( df.select(element_at(df("a"), 1)), @@ -615,6 +1190,64 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("element_at(a, -1)"), Seq(Row("3"), Row(""), Row(null)) ) + + val e = intercept[AnalysisException] { + Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)") + } + assert(e.message.contains( + "argument 1 requires (array or map) type, however, '`_1`' is of string type")) + } + + test("array_union functions") { + val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(1, 2, 3, 4)) + checkAnswer(df1.select(array_union($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_union(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array(-5, 4, -3, 2, -1))).toDF("a", "b") + val ans2 = Row(Seq(1, 2, null, 4, 5, -5, -3, -1)) + checkAnswer(df2.select(array_union($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_union(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 3L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(1L, 2L, 3L, 4L)) + checkAnswer(df3.select(array_union($"a", $"b")), ans3) + checkAnswer(df3.selectExpr("array_union(a, b)"), ans3) + + val df4 = Seq((Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array(-5L, 4L, -3L, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) + checkAnswer(df4.select(array_union($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_union(a, b)"), ans4) + + val df5 = Seq((Array("b", "a", "c"), Array("b", null, "a", "g"))).toDF("a", "b") + val ans5 = Row(Seq("b", "a", "c", null, "g")) + checkAnswer(df5.select(array_union($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_union(a, b)"), ans5) + + val df6 = Seq((null, Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df6.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df6.selectExpr("array_union(a, b)") + } + + val df7 = Seq((null, null)).toDF("a", "b") + intercept[AnalysisException] { + df7.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df7.selectExpr("array_union(a, b)") + } + + val df8 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df8.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df8.selectExpr("array_union(a, b)") + } } test("concat function - arrays") { @@ -689,6 +1322,233 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { intercept[AnalysisException] { df.selectExpr("concat(i1, array(i1, i2))") } + + val e = intercept[AnalysisException] { + df.selectExpr("concat(map(1, 2), map(3, 4))") + } + assert(e.getMessage.contains("string, binary or array")) + } + + test("flatten function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on + val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + + // Test cases with a primitive type + val intDF = Seq( + (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))), + (Seq(Seq(1, 2))), + (Seq(Seq(1), Seq.empty)), + (Seq(Seq.empty, Seq(1))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq(1), null)), + (Seq(null, Seq(1))), + (Seq(null, null)) + ).toDF("i") + + val intDFResult = Seq( + Row(Seq(1, 2, 3, 4, 5, 6)), + Row(Seq(1, 2)), + Row(Seq(1)), + Row(Seq(1)), + Row(Seq.empty), + Row(null), + Row(null), + Row(null)) + + checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.filter(dummyFilter($"i"))select(flatten($"i")), intDFResult) + checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) + checkAnswer( + oneRowDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), + Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) + + // Test cases with non-primitive types + val strDF = Seq( + (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))), + (Seq(Seq("a", "b"))), + (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))), + (Seq(Seq("a"), Seq.empty)), + (Seq(Seq.empty, Seq("a"))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq("a"), null)), + (Seq(null, Seq("a"))), + (Seq(null, null)) + ).toDF("s") + + val strDFResult = Seq( + Row(Seq("a", "b", "c", "d", "e", "f")), + Row(Seq("a", "b")), + Row(Seq("a", null, null, "b", null, null)), + Row(Seq("a")), + Row(Seq("a")), + Row(Seq.empty), + Row(null), + Row(null), + Row(null)) + + checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.filter(dummyFilter($"s")).select(flatten($"s")), strDFResult) + checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) + checkAnswer( + oneRowDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), + Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) + + // Error test cases + intercept[AnalysisException] { + oneRowDF.select(flatten($"arr")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"i")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"s")) + } + intercept[AnalysisException] { + oneRowDF.selectExpr("flatten(null)") + } + } + + test("array_repeat function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on + val strDF = Seq( + ("hi", 2), + (null, 2) + ).toDF("a", "b") + + val strDFTwiceResult = Seq( + Row(Seq("hi", "hi")), + Row(Seq(null, null)) + ) + + checkAnswer(strDF.select(array_repeat($"a", 2)), strDFTwiceResult) + checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), strDFTwiceResult) + checkAnswer(strDF.select(array_repeat($"a", $"b")), strDFTwiceResult) + checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, 2)"), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, b)"), strDFTwiceResult) + + val intDF = { + val schema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", IntegerType))) + val data = Seq( + Row(3, 2), + Row(null, 2) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + val intDFTwiceResult = Seq( + Row(Seq(3, 3)), + Row(Seq(null, null)) + ) + + checkAnswer(intDF.select(array_repeat($"a", 2)), intDFTwiceResult) + checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), intDFTwiceResult) + checkAnswer(intDF.select(array_repeat($"a", $"b")), intDFTwiceResult) + checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, 2)"), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, b)"), intDFTwiceResult) + + val nullCountDF = { + val schema = StructType(Seq( + StructField("a", StringType), + StructField("b", IntegerType))) + val data = Seq( + Row("hi", null), + Row(null, null) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + checkAnswer( + nullCountDF.select(array_repeat($"a", $"b")), + Seq( + Row(null), + Row(null) + ) + ) + + // Error test cases + val invalidTypeDF = Seq(("hi", "1")).toDF("a", "b") + + intercept[AnalysisException] { + invalidTypeDF.select(array_repeat($"a", $"b")) + } + intercept[AnalysisException] { + invalidTypeDF.select(array_repeat($"a", lit("1"))) + } + intercept[AnalysisException] { + invalidTypeDF.selectExpr("array_repeat(a, 1.0)") + } + + } + + test("array remove") { + val df = Seq( + (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2), + (Array.empty[Int], Array.empty[String], Array.empty[String], 2), + (null, null, null, 2) + ).toDF("a", "b", "c", "d") + checkAnswer( + df.select(array_remove($"a", 2), array_remove($"b", "a"), array_remove($"c", "")), + Seq( + Row(Seq(1, 3), Seq("b", "c"), Seq.empty[String]), + Row(Seq.empty[Int], Seq.empty[String], Seq.empty[String]), + Row(null, null, null)) + ) + + checkAnswer( + df.select(array_remove($"a", $"d")), + Seq( + Row(Seq(1, 3)), + Row(Seq.empty[Int]), + Row(null)) + ) + + checkAnswer( + df.selectExpr("array_remove(a, d)"), + Seq( + Row(Seq(1, 3)), + Row(Seq.empty[Int]), + Row(null)) + ) + + checkAnswer( + df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")", + "array_remove(c, \"\")"), + Seq( + Row(Seq(1, 3), Seq("b", "c"), Seq.empty[String]), + Row(Seq.empty[Int], Seq.empty[String], Seq.empty[String]), + Row(null, null, null)) + ) + + val e = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_remove(_1, _2)") + } + assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) + } + + test("array_distinct functions") { + val df = Seq( + (Array[Int](2, 1, 3, 4, 3, 5), Array("b", "c", "a", "c", "b", "", "")), + (Array.empty[Int], Array.empty[String]), + (null, null) + ).toDF("a", "b") + checkAnswer( + df.select(array_distinct($"a"), array_distinct($"b")), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("array_distinct(a)", "array_distinct(b)"), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) } private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 0d9eeabb397a1..10d9a11d2ee79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -287,4 +287,12 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.optimizedPlan } } + + test("SPARK-24385: Resolve ambiguity in self-joins with EqualNullSafe") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") { + val df = spark.range(2) + // this throws an exception before the fix + df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index a0fd74088ce8b..b0b46640ff317 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.sql -import java.util.concurrent.{CountDownLatch, TimeUnit} - import scala.concurrent.duration._ import scala.math.abs import scala.util.Random import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -154,53 +152,35 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } test("Cancelling stage in a query with Range.") { - // Save and restore the value because SparkContext is shared - val savedInterruptOnCancel = sparkContext - .getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) - - try { - sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") - - for (codegen <- Seq(true, false)) { - // This countdown latch used to make sure with all the stages cancelStage called in listener - val latch = new CountDownLatch(2) - - val listener = new SparkListener { - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { - sparkContext.cancelStage(taskStart.stageId) - latch.countDown() - } - } + val listener = new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sparkContext.cancelStage(taskStart.stageId) + } + } - sparkContext.addSparkListener(listener) - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - val ex = intercept[SparkException] { - sparkContext.range(0, 10000L, numSlices = 10).mapPartitions { x => - x.synchronized { - x.wait() - } - x - }.toDF("id").agg(sum("id")).collect() - } - ex.getCause() match { - case null => - assert(ex.getMessage().contains("cancelled")) - case cause: SparkException => - assert(cause.getMessage().contains("cancelled")) - case cause: Throwable => - fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") - } + sparkContext.addSparkListener(listener) + for (codegen <- Seq(true, false)) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { + val ex = intercept[SparkException] { + spark.range(0, 100000000000L, 1, 1) + .toDF("id").agg(sum("id")).collect() } - latch.await(20, TimeUnit.SECONDS) - eventually(timeout(20.seconds)) { - assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + ex.getCause() match { + case null => + assert(ex.getMessage().contains("cancelled")) + case cause: SparkException => + assert(cause.getMessage().contains("cancelled")) + case cause: Throwable => + fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } - sparkContext.removeSparkListener(listener) } - } finally { - sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, - savedInterruptOnCancel) + // Wait until all ListenerBus events consumed to make sure cancelStage called for all stages + sparkContext.listenerBus.waitUntilEmpty(20.seconds.toMillis) + eventually(timeout(20.seconds)) { + assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } } + sparkContext.removeSparkListener(listener) } test("SPARK-20430 Initialize Range parameters in a driver side") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 60e84e6ee7504..5babdf6f33b99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1044,6 +1044,65 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.select($"*").show(1000) } + test("getRows: truncate = [0, 20]") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = Seq( + Seq("value"), + Seq("1"), + Seq("111111111111111111111")) + assert(df.getRows(10, 0) === expectedAnswerForFalse) + val expectedAnswerForTrue = Seq( + Seq("value"), + Seq("1"), + Seq("11111111111111111...")) + assert(df.getRows(10, 20) === expectedAnswerForTrue) + } + + test("getRows: truncate = [3, 17]") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = Seq( + Seq("value"), + Seq("1"), + Seq("111")) + assert(df.getRows(10, 3) === expectedAnswerForFalse) + val expectedAnswerForTrue = Seq( + Seq("value"), + Seq("1"), + Seq("11111111111111...")) + assert(df.getRows(10, 17) === expectedAnswerForTrue) + } + + test("getRows: numRows = 0") { + val expectedAnswer = Seq(Seq("key", "value"), Seq("1", "1")) + assert(testData.select($"*").getRows(0, 20) === expectedAnswer) + } + + test("getRows: array") { + val df = Seq( + (Array(1, 2, 3), Array(1, 2, 3)), + (Array(2, 3, 4), Array(2, 3, 4)) + ).toDF() + val expectedAnswer = Seq( + Seq("_1", "_2"), + Seq("[1, 2, 3]", "[1, 2, 3]"), + Seq("[2, 3, 4]", "[2, 3, 4]")) + assert(df.getRows(10, 20) === expectedAnswer) + } + + test("getRows: binary") { + val df = Seq( + ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)), + ("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8)) + ).toDF() + val expectedAnswer = Seq( + Seq("_1", "_2"), + Seq("[31 32]", "[41 42 43 2E]"), + Seq("[33 34]", "[31 32 33 34 36]")) + assert(df.getRows(10, 20) === expectedAnswer) + } + test("showString: truncate = [0, 20]") { val longString = Array.fill(21)("1").mkString val df = sparkContext.parallelize(Seq("1", longString)).toDF() @@ -2261,8 +2320,96 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) } + test("SPARK-24165: CaseWhen/If - nullability of nested types") { + val rows = new java.util.ArrayList[Row]() + rows.add(Row(true, ("x", 1), Seq("x", "y"), Map(0 -> "x"))) + rows.add(Row(false, (null, 2), Seq(null, "z"), Map(0 -> null))) + val schema = StructType(Seq( + StructField("cond", BooleanType, true), + StructField("s", StructType(Seq( + StructField("val1", StringType, true), + StructField("val2", IntegerType, false) + )), false), + StructField("a", ArrayType(StringType, true)), + StructField("m", MapType(IntegerType, StringType, true)) + )) + + val sourceDF = spark.createDataFrame(rows, schema) + + val structWhenDF = sourceDF + .select(when('cond, struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise('s) as "res") + .select('res.getField("val1")) + val arrayWhenDF = sourceDF + .select(when('cond, array(lit("a"), lit("b"))).otherwise('a) as "res") + .select('res.getItem(0)) + val mapWhenDF = sourceDF + .select(when('cond, map(lit(0), lit("a"))).otherwise('m) as "res") + .select('res.getItem(0)) + + val structIfDF = sourceDF + .select(expr("if(cond, struct('a' as val1, 10 as val2), s)") as "res") + .select('res.getField("val1")) + val arrayIfDF = sourceDF + .select(expr("if(cond, array('a', 'b'), a)") as "res") + .select('res.getItem(0)) + val mapIfDF = sourceDF + .select(expr("if(cond, map(0, 'a'), m)") as "res") + .select('res.getItem(0)) + + def checkResult(df: DataFrame, codegenExpected: Boolean): Unit = { + assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] == codegenExpected) + checkAnswer(df, Seq(Row("a"), Row(null))) + } + + // without codegen + checkResult(structWhenDF, false) + checkResult(arrayWhenDF, false) + checkResult(mapWhenDF, false) + checkResult(structIfDF, false) + checkResult(arrayIfDF, false) + checkResult(mapIfDF, false) + + // with codegen + checkResult(structWhenDF.filter('cond.isNotNull), true) + checkResult(arrayWhenDF.filter('cond.isNotNull), true) + checkResult(mapWhenDF.filter('cond.isNotNull), true) + checkResult(structIfDF.filter('cond.isNotNull), true) + checkResult(arrayIfDF.filter('cond.isNotNull), true) + checkResult(mapIfDF.filter('cond.isNotNull), true) + } + test("Uuid expressions should produce same results at retries in the same DataFrame") { val df = spark.range(1).select($"id", new Column(Uuid())) checkAnswer(df, df.collect()) } + + test("SPARK-24313: access map with binary keys") { + val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1)) + checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1)) + } + + test("SPARK-24781: Using a reference from Dataset in Filter/Sort") { + val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") + val filter1 = df.select(df("name")).filter(df("id") === 0) + val filter2 = df.select(col("name")).filter(col("id") === 0) + checkAnswer(filter1, filter2.collect()) + + val sort1 = df.select(df("name")).orderBy(df("id")) + val sort2 = df.select(col("name")).orderBy(col("id")) + checkAnswer(sort1, sort2.collect()) + } + + test("SPARK-24781: Using a reference not in aggregation in Filter/Sort") { + withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") { + val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") + + val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name")) + val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name")) + checkAnswer(aggPlusSort1, aggPlusSort2.collect()) + + val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0) + val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0) + checkAnswer(aggPlusFilter1, aggPlusFilter2.collect()) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index 0ee9b0edc02b2..2a0b2b85e10a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -402,4 +402,18 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSQLContext { Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: Row(10, 6000) :: Nil) } + + test("SPARK-24033: Analysis Failure of OffsetWindowFunction") { + val ds = Seq((1, 1), (1, 2), (1, 3), (2, 1), (2, 2)).toDF("n", "i") + val res = + Row(1, 1, null) :: Row (1, 2, 1) :: Row(1, 3, 2) :: Row(2, 1, null) :: Row(2, 2, 1) :: Nil + checkAnswer( + ds.withColumn("m", + lead("i", -1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), + res) + checkAnswer( + ds.withColumn("m", + lag("i", 1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), + res) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 3ea398aad7375..97a843978f0bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql -import java.sql.{Date, Timestamp} - -import scala.collection.mutable +import org.scalatest.Matchers.the import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} @@ -27,7 +25,6 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval /** * Window function testing for DataFrame API. @@ -624,4 +621,41 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-24575: Window functions inside WHERE and HAVING clauses") { + def checkAnalysisError(df: => DataFrame): Unit = { + val thrownException = the [AnalysisException] thrownBy { + df.queryExecution.analyzed + } + assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses")) + } + + checkAnalysisError(testData2.select('a).where(rank().over(Window.orderBy('b)) === 1)) + checkAnalysisError(testData2.where('b === 2 && rank().over(Window.orderBy('b)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(avg('b).as("avgb")) + .where('a > 'avgb && rank().over(Window.orderBy('a)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(max('b).as("maxb"), sum('b).as("sumb")) + .where(rank().over(Window.orderBy('a)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(max('b).as("maxb"), sum('b).as("sumb")) + .where('sumb === 5 && rank().over(Window.orderBy('a)) === 1)) + + checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1")) + checkAnalysisError(sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1")) + checkAnalysisError( + sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1")) + checkAnalysisError( + sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1")) + checkAnalysisError( + sql( + s"""SELECT a, MAX(b) + |FROM testData2 + |GROUP BY a + |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 0e7eaa9e88d57..538ea3c66c40e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -148,6 +148,41 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] { } +case class OptionBooleanData(name: String, isGood: Option[Boolean]) + +case class OptionBooleanAggregator(colName: String) + extends Aggregator[Row, Option[Boolean], Option[Boolean]] { + + override def zero: Option[Boolean] = None + + override def reduce(buffer: Option[Boolean], row: Row): Option[Boolean] = { + val index = row.fieldIndex(colName) + val value = if (row.isNullAt(index)) { + Option.empty[Boolean] + } else { + Some(row.getBoolean(index)) + } + merge(buffer, value) + } + + override def merge(b1: Option[Boolean], b2: Option[Boolean]): Option[Boolean] = { + if ((b1.isDefined && b1.get) || (b2.isDefined && b2.get)) { + Some(true) + } else if (b1.isDefined) { + b1 + } else { + b2 + } + } + + override def finish(reduction: Option[Boolean]): Option[Boolean] = reduction + + override def bufferEncoder: Encoder[Option[Boolean]] = OptionalBoolEncoder + override def outputEncoder: Encoder[Option[Boolean]] = OptionalBoolEncoder + + def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder() +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -333,4 +368,29 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { df.groupBy($"i").agg(VeryComplexResultAgg.toColumn), Row(1, Row(Row(1, "a"), Row(1, "a"))) :: Row(2, Row(Row(2, "bc"), Row(2, "bc"))) :: Nil) } + + test("SPARK-24569: Aggregator with output type Option[Boolean] creates column of type Row") { + val df = Seq( + OptionBooleanData("bob", Some(true)), + OptionBooleanData("bob", Some(false)), + OptionBooleanData("bob", None)).toDF() + val group = df + .groupBy("name") + .agg(OptionBooleanAggregator("isGood").toColumn.alias("isGood")) + assert(df.schema == group.schema) + checkAnswer(group, Row("bob", true) :: Nil) + checkDataset(group.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) + } + + test("SPARK-24569: groupByKey with Aggregator of output type Option[Boolean]") { + val df = Seq( + OptionBooleanData("bob", Some(true)), + OptionBooleanData("bob", Some(false)), + OptionBooleanData("bob", None)).toDF() + val grouped = df.groupByKey((r: Row) => r.getString(0)) + .agg(OptionBooleanAggregator("isGood").toColumn).toDF("name", "isGood") + + assert(grouped.schema == df.schema) + checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index e0561ee2797a5..5c6a021d5b767 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -17,14 +17,28 @@ package org.apache.spark.sql +import org.scalatest.concurrent.TimeLimits +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.storage.StorageLevel -class DatasetCacheSuite extends QueryTest with SharedSQLContext { +class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits { import testImplicits._ + /** + * Asserts that a cached [[Dataset]] will be built using the given number of other cached results. + */ + private def assertCacheDependency(df: DataFrame, numOfCachesDependedUpon: Int = 1): Unit = { + val plan = df.queryExecution.withCachedData + assert(plan.isInstanceOf[InMemoryRelation]) + val internalPlan = plan.asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan + assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).size == numOfCachesDependedUpon) + } + test("get storage level") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") @@ -96,4 +110,100 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { agged.unpersist() assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.") } + + test("persist and then withColumn") { + val df = Seq(("test", 1)).toDF("s", "i") + val df2 = df.withColumn("newColumn", lit(1)) + + df.cache() + assertCached(df) + assertCached(df2) + + df.count() + assertCached(df2) + + df.unpersist() + assert(df.storageLevel == StorageLevel.NONE) + } + + test("cache UDF result correctly") { + val expensiveUDF = udf({x: Int => Thread.sleep(5000); x}) + val df = spark.range(0, 10).toDF("a").withColumn("b", expensiveUDF($"a")) + val df2 = df.agg(sum(df("b"))) + + df.cache() + df.count() + assertCached(df2) + + // udf has been evaluated during caching, and thus should not be re-evaluated here + failAfter(3 seconds) { + df2.collect() + } + + df.unpersist() + assert(df.storageLevel == StorageLevel.NONE) + } + + test("SPARK-24613 Cache with UDF could not be matched with subsequent dependent caches") { + val udf1 = udf({x: Int => x + 1}) + val df = spark.range(0, 10).toDF("a").withColumn("b", udf1($"a")) + val df2 = df.agg(sum(df("b"))) + + df.cache() + df.count() + df2.cache() + + assertCacheDependency(df2) + } + + test("SPARK-24596 Non-cascading Cache Invalidation") { + val df = Seq(("a", 1), ("b", 2)).toDF("s", "i") + val df2 = df.filter('i > 1) + val df3 = df.filter('i < 2) + + df2.cache() + df.cache() + df.count() + df3.cache() + + df.unpersist() + + // df un-cached; df2 and df3's cache plan re-compiled + assert(df.storageLevel == StorageLevel.NONE) + assertCacheDependency(df2, 0) + assertCacheDependency(df3, 0) + } + + test("SPARK-24596 Non-cascading Cache Invalidation - verify cached data reuse") { + val expensiveUDF = udf({ x: Int => Thread.sleep(5000); x }) + val df = spark.range(0, 5).toDF("a") + val df1 = df.withColumn("b", expensiveUDF($"a")) + val df2 = df1.groupBy('a).agg(sum('b)) + val df3 = df.agg(sum('a)) + + df1.cache() + df2.cache() + df2.collect() + df3.cache() + + assertCacheDependency(df2) + + df1.unpersist(blocking = true) + + // df1 un-cached; df2's cache plan re-compiled + assert(df1.storageLevel == StorageLevel.NONE) + assertCacheDependency(df1.groupBy('a).agg(sum('b)), 0) + + val df4 = df1.groupBy('a).agg(sum('b)).agg(sum("sum(b)")) + assertCached(df4) + // reuse loaded cache + failAfter(3 seconds) { + checkDataset(df4, Row(10)) + } + + val df5 = df.agg(sum('a)).filter($"sum(a)" > 1) + assertCached(df5) + // first time use, load cache + checkDataset(df5, Row(10)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index e0f4d2ba685e1..ce8db99d4e2f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1425,6 +1425,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23627: provide isEmpty in DataSet") { + val ds1 = spark.emptyDataset[Int] + val ds2 = Seq(1, 2, 3).toDS() + + assert(ds1.isEmpty == true) + assert(ds2.isEmpty == false) + } + test("SPARK-22472: add null check for top-level primitive values") { // If the primitive values are from Option, we need to do runtime null check. val ds = Seq(Some(1), None).toDS().as[Int] @@ -1458,6 +1466,38 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() intercept[NullPointerException](ds.as[(Int, Int)].collect()) } + + test("SPARK-24569: Option of primitive types are mistakenly mapped to struct type") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + val a = Seq(Some(1)).toDS + val b = Seq(Some(1.2)).toDS + val expected = Seq((Some(1), Some(1.2))).toDS + val joined = a.joinWith(b, lit(true)) + assert(joined.schema == expected.schema) + checkDataset(joined, expected.collect: _*) + } + } + + test("SPARK-24548: Dataset with tuple encoders should have correct schema") { + val encoder = Encoders.tuple(newStringEncoder, + Encoders.tuple(newStringEncoder, newStringEncoder)) + + val data = Seq(("a", ("1", "2")), ("b", ("3", "4"))) + val rdd = sparkContext.parallelize(data) + + val ds1 = spark.createDataset(rdd) + val ds2 = spark.createDataset(rdd)(encoder) + assert(ds1.schema == ds2.schema) + checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect(): _*) + } + + test("SPARK-24571: filtering of string values by char literal") { + val df = Seq("Amsterdam", "San Francisco", "X").toDF("city") + checkAnswer(df.where('city === 'X'), Seq(Row("X"))) + checkAnswer( + df.where($"city".contains(new java.lang.Character('A'))), + Seq(Row("Amsterdam"))) + } } case class TestDataUnion(x: Int, y: Int, z: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 6bbf38516cdf6..3af80b36ec42c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -23,6 +23,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.CalendarInterval @@ -327,6 +328,13 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { val df = Seq((t1, d1, s1), (t2, d2, s2)).toDF("t", "d", "s") checkAnswer(df.select(months_between(col("t"), col("d"))), Seq(Row(-10.0), Row(7.0))) checkAnswer(df.selectExpr("months_between(t, s)"), Seq(Row(0.5), Row(-0.5))) + checkAnswer(df.selectExpr("months_between(t, s, true)"), Seq(Row(0.5), Row(-0.5))) + Seq(true, false).foreach { roundOff => + checkAnswer(df.select(months_between(col("t"), col("d"), roundOff)), + Seq(Row(-10.0), Row(7.0))) + checkAnswer(df.withColumn("r", lit(false)).selectExpr("months_between(t, s, r)"), + Seq(Row(0.5), Row(-0.5))) + } } test("function last_day") { @@ -655,7 +663,7 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("datediff(a, d)"), Seq(Row(1), Row(1))) } - test("from_utc_timestamp") { + test("from_utc_timestamp with literal zone") { val df = Seq( (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") @@ -672,7 +680,24 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-24 17:00:00")))) } - test("to_utc_timestamp") { + test("from_utc_timestamp with column zone") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "CET"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "PST") + ).toDF("a", "b", "c") + checkAnswer( + df.select(from_utc_timestamp(col("a"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 02:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + checkAnswer( + df.select(from_utc_timestamp(col("b"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 02:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + } + + test("to_utc_timestamp with literal zone") { val df = Seq( (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") @@ -689,4 +714,28 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-25 07:00:00")))) } + test("to_utc_timestamp with column zone") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "PST"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "CET") + ).toDF("a", "b", "c") + checkAnswer( + df.select(to_utc_timestamp(col("a"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-24 22:00:00")))) + checkAnswer( + df.select(to_utc_timestamp(col("b"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-24 22:00:00")))) + } + + test("SPARK-23715: to/from_utc_timestamp can retain the previous behavior") { + withSQLConf(SQLConf.REJECT_TIMEZONE_IN_STRING.key -> "false") { + checkAnswer( + sql("SELECT from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')"), + Row(Timestamp.valueOf("2000-10-09 18:00:00"))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 06303099f5310..a7ce952b70ac1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql -import java.io.FileNotFoundException +import java.io.{File, FileNotFoundException} +import java.util.Locale import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException +import org.apache.spark.sql.TestingUDT.{IntervalData, IntervalUDT, NullData, NullUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -202,4 +204,258 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } } + + // Text file format only supports string type + test("SPARK-24691 error handling for unsupported types - text") { + withTempDir { dir => + // write path + val textDir = new File(dir, "text").getCanonicalPath + var msg = intercept[AnalysisException] { + Seq(1).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support int data type")) + + msg = intercept[AnalysisException] { + Seq(1.2).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support double data type")) + + msg = intercept[AnalysisException] { + Seq(true).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support boolean data type")) + + msg = intercept[AnalysisException] { + Seq(1).toDF("a").selectExpr("struct(a)").write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support struct data type")) + + msg = intercept[AnalysisException] { + Seq((Map("Tesla" -> 3))).toDF("cars").write.mode("overwrite").text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support map data type")) + + msg = intercept[AnalysisException] { + Seq((Array("Tesla", "Chevy", "Ford"))).toDF("brands") + .write.mode("overwrite").text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support array data type")) + + // read path + Seq("aaa").toDF.write.mode("overwrite").text(textDir) + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support int data type")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", DoubleType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support double data type")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", BooleanType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support boolean data type")) + } + } + + // Unsupported data types of csv, json, orc, and parquet are as follows; + // csv -> R/W: Null, Array, Map, Struct + // json -> R/W: Interval + // orc -> R/W: Interval, W: Null + // parquet -> R/W: Interval, Null + test("SPARK-24204 error handling for unsupported Array/Map/Struct types - csv") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + var msg = intercept[AnalysisException] { + Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support struct data type")) + + msg = intercept[AnalysisException] { + val schema = StructType.fromDDL("a struct") + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support struct data type")) + + msg = intercept[AnalysisException] { + Seq((1, Map("Tesla" -> 3))).toDF("id", "cars").write.mode("overwrite").csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support map data type")) + + msg = intercept[AnalysisException] { + val schema = StructType.fromDDL("a map") + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support map data type")) + + msg = intercept[AnalysisException] { + Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands") + .write.mode("overwrite").csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[AnalysisException] { + val schema = StructType.fromDDL("a array") + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[AnalysisException] { + Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") + .write.mode("overwrite").csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support mydensevector data type")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support mydensevector data type.")) + } + } + + test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") { + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + + // write path + Seq("csv", "json", "parquet", "orc").foreach { format => + var msg = intercept[AnalysisException] { + sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.contains("Cannot save interval data type into external storage.")) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new IntervalData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support interval data type.")) + } + + // read path + Seq("parquet", "csv").foreach { format => + var msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support calendarinterval data type.")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support interval data type.")) + } + } + } + + test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") { + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + + Seq("orc").foreach { format => + // write path + var msg = intercept[AnalysisException] { + sql("select null").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new NullData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + // read path + // We expect the types below should be passed for backward-compatibility + + // Null type + var schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + + // UDT having null data + schema = StructType(StructField("a", new NullUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + } + + Seq("parquet", "csv").foreach { format => + // write path + var msg = intercept[AnalysisException] { + sql("select null").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new NullData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + // read path + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + } + } + } +} + +object TestingUDT { + + @SQLUserDefinedType(udt = classOf[IntervalUDT]) + class IntervalData extends Serializable + + class IntervalUDT extends UserDefinedType[IntervalData] { + + override def sqlType: DataType = CalendarIntervalType + override def serialize(obj: IntervalData): Any = + throw new NotImplementedError("Not implemented") + override def deserialize(datum: Any): IntervalData = + throw new NotImplementedError("Not implemented") + override def userClass: Class[IntervalData] = classOf[IntervalData] + } + + @SQLUserDefinedType(udt = classOf[NullUDT]) + private[sql] class NullData extends Serializable + + private[sql] class NullUDT extends UserDefinedType[NullData] { + + override def sqlType: DataType = NullType + override def serialize(obj: NullData): Any = throw new NotImplementedError("Not implemented") + override def deserialize(datum: Any): NullData = + throw new NotImplementedError("Not implemented") + override def userClass: Class[NullData] = classOf[NullData] + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala deleted file mode 100644 index c6dd7dadc9d93..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala +++ /dev/null @@ -1,243 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.io.File - -import scala.util.{Random, Try} - -import org.apache.spark.SparkConf -import org.apache.spark.sql.functions.monotonically_increasing_id -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{Benchmark, Utils} - - -/** - * Benchmark to measure read performance with Filter pushdown. - */ -object FilterPushdownBenchmark { - val conf = new SparkConf() - conf.set("orc.compression", "snappy") - conf.set("spark.sql.parquet.compression.codec", "snappy") - - private val spark = SparkSession.builder() - .master("local[1]") - .appName("FilterPushdownBenchmark") - .config(conf) - .getOrCreate() - - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(spark.catalog.dropTempView) - } - - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } - - private def prepareTable(dir: File, numRows: Int, width: Int): Unit = { - import spark.implicits._ - val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") - val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) - .withColumn("id", monotonically_increasing_id()) - - val dirORC = dir.getCanonicalPath + "/orc" - val dirParquet = dir.getCanonicalPath + "/parquet" - - df.write.mode("overwrite").orc(dirORC) - df.write.mode("overwrite").parquet(dirParquet) - - spark.read.orc(dirORC).createOrReplaceTempView("orcTable") - spark.read.parquet(dirParquet).createOrReplaceTempView("parquetTable") - } - - def filterPushDownBenchmark( - values: Int, - title: String, - whereExpr: String, - selectExpr: String = "*"): Unit = { - val benchmark = new Benchmark(title, values, minNumIters = 5) - - Seq(false, true).foreach { pushDownEnabled => - val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" - benchmark.addCase(name) { _ => - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { - spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() - } - } - } - - Seq(false, true).foreach { pushDownEnabled => - val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" - benchmark.addCase(name) { _ => - withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { - spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() - } - } - } - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz - - Select 0 row (id IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7882 / 7957 2.0 501.1 1.0X - Parquet Vectorized (Pushdown) 55 / 60 285.2 3.5 142.9X - Native ORC Vectorized 5592 / 5627 2.8 355.5 1.4X - Native ORC Vectorized (Pushdown) 66 / 70 237.2 4.2 118.9X - - Select 0 row (7864320 < id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7884 / 7909 2.0 501.2 1.0X - Parquet Vectorized (Pushdown) 739 / 752 21.3 47.0 10.7X - Native ORC Vectorized 5614 / 5646 2.8 356.9 1.4X - Native ORC Vectorized (Pushdown) 81 / 83 195.2 5.1 97.8X - - Select 1 row (id = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7905 / 8027 2.0 502.6 1.0X - Parquet Vectorized (Pushdown) 740 / 766 21.2 47.1 10.7X - Native ORC Vectorized 5684 / 5738 2.8 361.4 1.4X - Native ORC Vectorized (Pushdown) 78 / 81 202.4 4.9 101.7X - - Select 1 row (id <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7928 / 7993 2.0 504.1 1.0X - Parquet Vectorized (Pushdown) 747 / 772 21.0 47.5 10.6X - Native ORC Vectorized 5728 / 5753 2.7 364.2 1.4X - Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 104.8X - - Select 1 row (7864320 <= id <= 7864320):Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7939 / 8021 2.0 504.8 1.0X - Parquet Vectorized (Pushdown) 746 / 770 21.1 47.4 10.6X - Native ORC Vectorized 5690 / 5734 2.8 361.7 1.4X - Native ORC Vectorized (Pushdown) 76 / 79 206.7 4.8 104.3X - - Select 1 row (7864319 < id < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7972 / 8019 2.0 506.9 1.0X - Parquet Vectorized (Pushdown) 742 / 764 21.2 47.2 10.7X - Native ORC Vectorized 5704 / 5743 2.8 362.6 1.4X - Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 105.4X - - Select 10% rows (id < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 8733 / 8808 1.8 555.2 1.0X - Parquet Vectorized (Pushdown) 2213 / 2267 7.1 140.7 3.9X - Native ORC Vectorized 6420 / 6463 2.4 408.2 1.4X - Native ORC Vectorized (Pushdown) 1313 / 1331 12.0 83.5 6.7X - - Select 50% rows (id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 11518 / 11591 1.4 732.3 1.0X - Parquet Vectorized (Pushdown) 7962 / 7991 2.0 506.2 1.4X - Native ORC Vectorized 8927 / 8985 1.8 567.6 1.3X - Native ORC Vectorized (Pushdown) 6102 / 6160 2.6 387.9 1.9X - - Select 90% rows (id < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14255 / 14389 1.1 906.3 1.0X - Parquet Vectorized (Pushdown) 13564 / 13594 1.2 862.4 1.1X - Native ORC Vectorized 11442 / 11608 1.4 727.5 1.2X - Native ORC Vectorized (Pushdown) 10991 / 11029 1.4 698.8 1.3X - - Select all rows (id IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14917 / 14938 1.1 948.4 1.0X - Parquet Vectorized (Pushdown) 14910 / 14964 1.1 948.0 1.0X - Native ORC Vectorized 11986 / 12069 1.3 762.0 1.2X - Native ORC Vectorized (Pushdown) 12037 / 12123 1.3 765.3 1.2X - - Select all rows (id > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14951 / 14976 1.1 950.6 1.0X - Parquet Vectorized (Pushdown) 14934 / 15016 1.1 949.5 1.0X - Native ORC Vectorized 12000 / 12156 1.3 763.0 1.2X - Native ORC Vectorized (Pushdown) 12079 / 12113 1.3 767.9 1.2X - - Select all rows (id != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14930 / 14972 1.1 949.3 1.0X - Parquet Vectorized (Pushdown) 15015 / 15047 1.0 954.6 1.0X - Native ORC Vectorized 12090 / 12259 1.3 768.7 1.2X - Native ORC Vectorized (Pushdown) 12021 / 12096 1.3 764.2 1.2X - */ - benchmark.run() - } - - def main(args: Array[String]): Unit = { - val numRows = 1024 * 1024 * 15 - val width = 5 - val mid = numRows / 2 - - withTempPath { dir => - withTempTable("orcTable", "patquetTable") { - prepareTable(dir, numRows, width) - - Seq("id IS NULL", s"$mid < id AND id < $mid").foreach { whereExpr => - val title = s"Select 0 row ($whereExpr)".replace("id AND id", "id") - filterPushDownBenchmark(numRows, title, whereExpr) - } - - Seq( - s"id = $mid", - s"id <=> $mid", - s"$mid <= id AND id <= $mid", - s"${mid - 1} < id AND id < ${mid + 1}" - ).foreach { whereExpr => - val title = s"Select 1 row ($whereExpr)".replace("id AND id", "id") - filterPushDownBenchmark(numRows, title, whereExpr) - } - - val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(id)") - - Seq(10, 50, 90).foreach { percent => - filterPushDownBenchmark( - numRows, - s"Select $percent% rows (id < ${numRows * percent / 100})", - s"id < ${numRows * percent / 100}", - selectExpr - ) - } - - Seq("id IS NOT NULL", "id > -1", "id != -1").foreach { whereExpr => - filterPushDownBenchmark( - numRows, - s"Select all rows ($whereExpr)", - whereExpr, - selectExpr) - } - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 109fcf90a3ec9..8280a3ce39845 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, Generator} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} @@ -315,6 +316,7 @@ case class EmptyGenerator() extends Generator { override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val iteratorClass = classOf[Iterator[_]].getName - ev.copy(code = s"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") + ev.copy(code = + code"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala new file mode 100644 index 0000000000000..147c0b61f5017 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans.logical.AnalysisBarrier +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{LongType, StructField, StructType} + +class GroupedDatasetSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + private val scalaUDF = udf((x: Long) => { x + 1 }) + private lazy val datasetWithUDF = spark.range(1).toDF("s").select($"s", scalaUDF($"s")) + + private def assertContainsAnalysisBarrier(ds: Dataset[_], atLevel: Int = 1): Unit = { + assert(atLevel >= 0) + var children = Seq(ds.queryExecution.logical) + (1 to atLevel).foreach { _ => + children = children.flatMap(_.children) + } + val barriers = children.collect { + case ab: AnalysisBarrier => ab + } + assert(barriers.nonEmpty, s"Plan does not contain AnalysisBarrier at level $atLevel:\n" + + ds.queryExecution.logical) + } + + test("SPARK-24373: avoid running Analyzer rules twice on RelationalGroupedDataset") { + val groupByDataset = datasetWithUDF.groupBy() + val rollupDataset = datasetWithUDF.rollup("s") + val cubeDataset = datasetWithUDF.cube("s") + val pivotDataset = datasetWithUDF.groupBy().pivot("s", Seq(1, 2)) + datasetWithUDF.cache() + Seq(groupByDataset, rollupDataset, cubeDataset, pivotDataset).foreach { rgDS => + val df = rgDS.count() + assertContainsAnalysisBarrier(df) + assertCached(df) + } + + val flatMapGroupsInRDF = datasetWithUDF.groupBy().flatMapGroupsInR( + Array.emptyByteArray, + Array.emptyByteArray, + Array.empty, + StructType(Seq(StructField("s", LongType)))) + val flatMapGroupsInPandasDF = datasetWithUDF.groupBy().flatMapGroupsInPandas(PythonUDF( + "pyUDF", + null, + StructType(Seq(StructField("s", LongType))), + Seq.empty, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + true)) + Seq(flatMapGroupsInRDF, flatMapGroupsInPandasDF).foreach { df => + assertContainsAnalysisBarrier(df, 2) + assertCached(df) + } + datasetWithUDF.unpersist(true) + } + + test("SPARK-24373: avoid running Analyzer rules twice on KeyValueGroupedDataset") { + val kvDasaset = datasetWithUDF.groupByKey(_.getLong(0)) + datasetWithUDF.cache() + val mapValuesKVDataset = kvDasaset.mapValues(_.getLong(0)).reduceGroups(_ + _) + val keysKVDataset = kvDasaset.keys + val flatMapGroupsKVDataset = kvDasaset.flatMapGroups((k, _) => Seq(k)) + val aggKVDataset = kvDasaset.count() + val otherKVDataset = spark.range(1).groupByKey(_ + 1) + val cogroupKVDataset = kvDasaset.cogroup(otherKVDataset)((k, _, _) => Seq(k)) + Seq((mapValuesKVDataset, 1), + (keysKVDataset, 2), + (flatMapGroupsKVDataset, 2), + (aggKVDataset, 1), + (cogroupKVDataset, 2)).foreach { case (df, analysisBarrierDepth) => + assertContainsAnalysisBarrier(df, analysisBarrierDepth) + assertCached(df) + } + datasetWithUDF.unpersist(true) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 771e1186e63ab..44767dfc92497 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -239,7 +239,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(2, 2, 1, null) :: Row(2, 2, 2, 2) :: Nil) } - assert(e.getMessage.contains("Detected cartesian product for INNER join " + + assert(e.getMessage.contains("Detected implicit cartesian product for INNER join " + "between logical plans")) } } @@ -611,7 +611,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val e = intercept[Exception] { checkAnswer(sql(query), Nil); } - assert(e.getMessage.contains("Detected cartesian product")) + assert(e.getMessage.contains("Detected implicit cartesian product")) } cartesianQueries.foreach(checkCartesianDetection) @@ -882,4 +882,15 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil) } } + + test("SPARK-24495: Join may return wrong result when having duplicated equal-join keys") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.range(0, 100, 1, 2) + val df2 = spark.range(100).select($"id".as("b1"), (- $"id").as("b2")) + val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2").select($"b1", $"b2", $"id") + checkAnswer(res, Row(0, 0, 0)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 00d2acc4a1d8a..d3b2701f2558e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.functions.{from_json, lit, map, struct, to_json} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -311,7 +311,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { val errMsg1 = intercept[AnalysisException] { df3.selectExpr("from_json(value, 1)") } - assert(errMsg1.getMessage.startsWith("Expected a string literal instead of")) + assert(errMsg1.getMessage.startsWith("Schema should be specified in DDL format as a string")) val errMsg2 = intercept[AnalysisException] { df3.selectExpr("""from_json(value, 'time InvalidType')""") } @@ -326,4 +326,83 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { assert(errMsg4.getMessage.startsWith( "A type of keys and values in map() must be string, but got")) } + + test("SPARK-24027: from_json - map") { + val in = Seq("""{"a": 1, "b": 2, "c": 3}""").toDS() + val schema = + """ + |{ + | "type" : "map", + | "keyType" : "string", + | "valueType" : "integer", + | "valueContainsNull" : true + |} + """.stripMargin + val out = in.select(from_json($"value", schema, Map[String, String]())) + + assert(out.columns.head == "entries") + checkAnswer(out, Row(Map("a" -> 1, "b" -> 2, "c" -> 3))) + } + + test("SPARK-24027: from_json - map") { + val in = Seq("""{"a": {"b": 1}}""").toDS() + val schema = MapType(StringType, new StructType().add("b", IntegerType), true) + val out = in.select(from_json($"value", schema)) + + checkAnswer(out, Row(Map("a" -> Row(1)))) + } + + test("SPARK-24027: from_json - map>") { + val in = Seq("""{"a": {"b": 1}}""").toDS() + val schema = "map>" + val out = in.select(from_json($"value", schema, Map.empty[String, String])) + + checkAnswer(out, Row(Map("a" -> Map("b" -> 1)))) + } + + test("SPARK-24027: roundtrip - from_json -> to_json - map") { + val json = """{"a":1,"b":2,"c":3}""" + val schema = MapType(StringType, IntegerType, true) + val out = Seq(json).toDS().select(to_json(from_json($"value", schema))) + + checkAnswer(out, Row(json)) + } + + test("SPARK-24027: roundtrip - to_json -> from_json - map") { + val in = Seq(Map("a" -> 1)).toDF() + val schema = MapType(StringType, IntegerType, true) + val out = in.select(from_json(to_json($"value"), schema)) + + checkAnswer(out, in) + } + + test("SPARK-24027: from_json - wrong map") { + val in = Seq("""{"a" 1}""").toDS() + val schema = MapType(StringType, IntegerType) + val out = in.select(from_json($"value", schema, Map[String, String]())) + + checkAnswer(out, Row(null)) + } + + test("SPARK-24027: from_json of a map with unsupported key type") { + val schema = MapType(StructType(StructField("f", IntegerType) :: Nil), StringType) + + checkAnswer(Seq("""{{"f": 1}: "a"}""").toDS().select(from_json($"value", schema)), + Row(null)) + checkAnswer(Seq("""{"{"f": 1}": "a"}""").toDS().select(from_json($"value", schema)), + Row(null)) + } + + test("SPARK-24709: infers schemas of json strings and pass them to from_json") { + val in = Seq("""{"a": [1, 2, 3]}""").toDS() + val out = in.select(from_json('value, schema_of_json(lit("""{"a": [1]}"""))) as "parsed") + val expected = StructType(StructField( + "parsed", + StructType(StructField( + "a", + ArrayType(LongType, true), true) :: Nil), + true) :: Nil) + + assert(out.schema == expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala index d66a6902b0510..cbef1c7828319 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala @@ -30,11 +30,15 @@ trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self override def beforeAll() { super.beforeAll() InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } override def afterEach() { try { resetSparkContext() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } finally { super.afterEach() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala index cfe2e9f2dbc44..cdcea09ad9758 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala @@ -54,4 +54,18 @@ class RuntimeConfigSuite extends SparkFunSuite { conf.get("k1") } } + + test("SPARK-24761: is a config parameter modifiable") { + val conf = newConf() + + // SQL configs + assert(!conf.isModifiable("spark.sql.sources.schemaStringLengthThreshold")) + assert(conf.isModifiable("spark.sql.streaming.checkpointLocation")) + // Core configs + assert(!conf.isModifiable("spark.task.cpus")) + assert(!conf.isModifiable("spark.executor.cores")) + // Invalid config parameters + assert(!conf.isModifiable("")) + assert(!conf.isModifiable("invalid config parameter")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 640affc10ee58..dfb9c137b74f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2792,4 +2792,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-24696 ColumnPruning rule fails to remove extra Project") { + withTable("fact_stats", "dim_stats") { + val factData = Seq((1, 1, 99, 1), (2, 2, 99, 2), (3, 1, 99, 3), (4, 2, 99, 4)) + val storeData = Seq((1, "BW", "DE"), (2, "AZ", "US")) + spark.udf.register("filterND", udf((value: Int) => value > 2).asNondeterministic) + factData.toDF("date_id", "store_id", "product_id", "units_sold") + .write.mode("overwrite").partitionBy("store_id").format("parquet").saveAsTable("fact_stats") + storeData.toDF("store_id", "state_province", "country") + .write.mode("overwrite").format("parquet").saveAsTable("dim_stats") + val df = sql( + """ + |SELECT f.date_id, f.product_id, f.store_id FROM + |(SELECT date_id, product_id, store_id + | FROM fact_stats WHERE filterND(date_id)) AS f + |JOIN dim_stats s + |ON f.store_id = s.store_id WHERE s.country = 'DE' + """.stripMargin) + checkAnswer(df, Seq(Row(3, 99, 1))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index beac9699585d5..826408c7161e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -54,6 +54,7 @@ import org.apache.spark.sql.types.StructType * The format for input files is simple: * 1. A list of SQL queries separated by semicolon. * 2. Lines starting with -- are treated as comments and ignored. + * 3. Lines starting with --SET are used to run the file with the following set of configs. * * For example: * {{{ @@ -138,18 +139,58 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { private def runTest(testCase: TestCase): Unit = { val input = fileToString(new File(testCase.inputFile)) + val (comments, code) = input.split("\n").partition(_.startsWith("--")) + val configSets = { + val configLines = comments.filter(_.startsWith("--SET")).map(_.substring(5)) + val configs = configLines.map(_.split(",").map { confAndValue => + val (conf, value) = confAndValue.span(_ != '=') + conf.trim -> value.substring(1).trim + }) + // When we are regenerating the golden files we don't need to run all the configs as they + // all need to return the same result + if (regenerateGoldenFiles && configs.nonEmpty) { + configs.take(1) + } else { + configs + } + } // List of SQL queries to run - val queries: Seq[String] = { - val cleaned = input.split("\n").filterNot(_.startsWith("--")).mkString("\n") - // note: this is not a robust way to split queries using semicolon, but works for now. - cleaned.split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + // note: this is not a robust way to split queries using semicolon, but works for now. + val queries = code.mkString("\n").split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + + if (configSets.isEmpty) { + runQueries(queries, testCase.resultFile, None) + } else { + configSets.foreach { configSet => + try { + runQueries(queries, testCase.resultFile, Some(configSet)) + } catch { + case e: Throwable => + val configs = configSet.map { + case (k, v) => s"$k=$v" + } + logError(s"Error using configs: ${configs.mkString(",")}") + throw e + } + } } + } + private def runQueries( + queries: Seq[String], + resultFileName: String, + configSet: Option[Seq[(String, String)]]): Unit = { // Create a local SparkSession to have stronger isolation between different test cases. // This does not isolate catalog changes. val localSparkSession = spark.newSession() loadTestData(localSparkSession) + if (configSet.isDefined) { + // Execute the list of set operation in order to add the desired configs + val setOperations = configSet.get.map { case (key, value) => s"set $key=$value" } + logInfo(s"Setting configs: ${setOperations.mkString(", ")}") + setOperations.foreach(localSparkSession.sql) + } // Run the SQL queries preparing them for comparison. val outputs: Seq[QueryOutput] = queries.map { sql => val (schema, output) = getNormalizedResult(localSparkSession, sql) @@ -167,7 +208,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { s"-- Number of queries: ${outputs.size}\n\n\n" + outputs.zipWithIndex.map{case (qr, i) => qr.toString(i)}.mkString("\n\n\n") + "\n" } - val resultFile = new File(testCase.resultFile) + val resultFile = new File(resultFileName) val parent = resultFile.getParentFile if (!parent.exists()) { assert(parent.mkdirs(), "Could not create directory: " + parent) @@ -177,7 +218,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { // Read back the golden file. val expectedOutputs: Seq[QueryOutput] = { - val goldenOutput = fileToString(new File(testCase.resultFile)) + val goldenOutput = fileToString(new File(resultFileName)) val segments = goldenOutput.split("-- !query.+\n") // each query has 3 segments, plus the header diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 14a565863d66c..60fa951e23178 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -50,7 +50,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") - assert(sizes.head === BigInt(96), + assert(sizes.head === BigInt(128), s"expected exact size 96 for table 'test', got: ${sizes.head}") } } @@ -382,4 +382,32 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } } + + test("Simple queries must be working, if CBO is turned on") { + withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { + withTable("TBL1", "TBL") { + import org.apache.spark.sql.functions._ + val df = spark.range(1000L).select('id, + 'id * 2 as "FLD1", + 'id * 12 as "FLD2", + lit("aaa") + 'id as "fld3") + df.write + .mode(SaveMode.Overwrite) + .bucketBy(10, "id", "FLD1", "FLD2") + .sortBy("id", "FLD1", "FLD2") + .saveAsTable("TBL") + sql("ANALYZE TABLE TBL COMPUTE STATISTICS ") + sql("ANALYZE TABLE TBL COMPUTE STATISTICS FOR COLUMNS ID, FLD1, FLD2, FLD3") + val df2 = spark.sql( + """ + |SELECT t1.id, t1.fld1, t1.fld2, t1.fld3 + |FROM tbl t1 + |JOIN tbl t2 on t1.id=t2.id + |WHERE t1.fld3 IN (-123.23,321.23) + """.stripMargin) + df2.createTempView("TBL2") + sql("SELECT * FROM tbl2 WHERE fld3 IN ('qqq', 'qwe') ").queryExecution.executedPlan + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 31e8b0e8dede0..acef62d81ee12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -955,4 +955,19 @@ class SubquerySuite extends QueryTest with SharedSQLContext { // before the fix this would throw AnalysisException spark.range(10).where("(id,id) in (select id, null from range(3))").count } + + test("SPARK-24085 scalar subquery in partitioning expression") { + withTable("parquet_part") { + Seq("1" -> "a", "2" -> "a", "3" -> "b", "4" -> "b") + .toDF("id_value", "id_type") + .write + .mode(SaveMode.Overwrite) + .partitionBy("id_type") + .format("parquet") + .saveAsTable("parquet_part") + checkAnswer( + sql("SELECT * FROM parquet_part WHERE id_type = (SELECT 'b')"), + Row("3", "b") :: Row("4", "b") :: Nil) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 40915a102bab0..d254345e8fa54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, Row} +import org.apache.spark.sql.{execution, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ @@ -69,6 +69,27 @@ class PlannerSuite extends SharedSQLContext { testPartialAggregationPlan(query) } + test("mixed aggregates with same distinct columns") { + def assertNoExpand(plan: SparkPlan): Unit = { + assert(plan.collect { case e: ExpandExec => e }.isEmpty) + } + + withTempView("v") { + Seq((1, 1.0, 1.0), (1, 2.0, 2.0)).toDF("i", "j", "k").createTempView("v") + // one distinct column + val query1 = sql("SELECT sum(DISTINCT j), max(DISTINCT j) FROM v GROUP BY i") + assertNoExpand(query1.queryExecution.executedPlan) + + // 2 distinct columns + val query2 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT j, k) FROM v GROUP BY i") + assertNoExpand(query2.queryExecution.executedPlan) + + // 2 distinct columns with different order + val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i") + assertNoExpand(query3.queryExecution.executedPlan) + } + } + test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { def checkPlan(fieldTypes: Seq[DataType]): Unit = { withTempView("testLimit") { @@ -194,7 +215,19 @@ class PlannerSuite extends SharedSQLContext { test("CollectLimit can appear in the middle of a plan when caching is used") { val query = testData.select('key, 'value).limit(2).cache() val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] - assert(planned.child.isInstanceOf[CollectLimitExec]) + assert(planned.cachedPlan.isInstanceOf[CollectLimitExec]) + } + + test("TakeOrderedAndProjectExec appears only when number of limit is below the threshold.") { + withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "1000") { + val query0 = testData.select('value).orderBy('key).limit(100) + val planned0 = query0.queryExecution.executedPlan + assert(planned0.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined) + + val query1 = testData.select('value).orderBy('key).limit(2000) + val planned1 = query1.queryExecution.executedPlan + assert(planned1.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isEmpty) + } } test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { @@ -621,6 +654,115 @@ class PlannerSuite extends SharedSQLContext { requiredOrdering = Seq(orderingA, orderingB), shouldHaveSort = true) } + + test("SPARK-24242: RangeExec should have correct output ordering and partitioning") { + val df = spark.range(10) + val rangeExec = df.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + val range = df.queryExecution.optimizedPlan.collect { + case r: Range => r + } + assert(rangeExec.head.outputOrdering == range.head.outputOrdering) + assert(rangeExec.head.outputPartitioning == + RangePartitioning(rangeExec.head.outputOrdering, df.rdd.getNumPartitions)) + + val rangeInOnePartition = spark.range(1, 10, 1, 1) + val rangeExecInOnePartition = rangeInOnePartition.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + assert(rangeExecInOnePartition.head.outputPartitioning == SinglePartition) + + val rangeInZeroPartition = spark.range(-10, -9, -20, 1) + val rangeExecInZeroPartition = rangeInZeroPartition.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0)) + } + + test("SPARK-24495: EnsureRequirements can return wrong plan when reusing the same key in join") { + val plan1 = DummySparkPlan(outputOrdering = Seq(orderingA), + outputPartitioning = HashPartitioning(exprA :: exprA :: Nil, 5)) + val plan2 = DummySparkPlan(outputOrdering = Seq(orderingB), + outputPartitioning = HashPartitioning(exprB :: Nil, 5)) + val smjExec = SortMergeJoinExec( + exprA :: exprA :: Nil, exprB :: exprC :: Nil, Inner, None, plan1, plan2) + + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec) + outputPlan match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) => + assert(leftKeys == Seq(exprA, exprA)) + assert(rightKeys == Seq(exprB, exprC)) + case _ => fail() + } + } + + test("SPARK-24500: create union with stream of children") { + val df = Union(Stream( + Range(1, 1, 1, 1), + Range(1, 2, 1, 1))) + df.queryExecution.executedPlan.execute() + } + + test("SPARK-24556: always rewrite output partitioning in ReusedExchangeExec " + + "and InMemoryTableScanExec") { + def checkOutputPartitioningRewrite( + plans: Seq[SparkPlan], + expectedPartitioningClass: Class[_]): Unit = { + assert(plans.size == 1) + val plan = plans.head + val partitioning = plan.outputPartitioning + assert(partitioning.getClass == expectedPartitioningClass) + val partitionedAttrs = partitioning.asInstanceOf[Expression].references + assert(partitionedAttrs.subsetOf(plan.outputSet)) + } + + def checkReusedExchangeOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val reusedExchange = df.queryExecution.executedPlan.collect { + case r: ReusedExchangeExec => r + } + checkOutputPartitioningRewrite(reusedExchange, expectedPartitioningClass) + } + + def checkInMemoryTableScanOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val inMemoryScan = df.queryExecution.executedPlan.collect { + case m: InMemoryTableScanExec => m + } + checkOutputPartitioningRewrite(inMemoryScan, expectedPartitioningClass) + } + + // ReusedExchange is HashPartitioning + val df1 = Seq(1 -> "a").toDF("i", "j").repartition($"i") + val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i") + checkReusedExchangeOutputPartitioningRewrite(df1.union(df2), classOf[HashPartitioning]) + + // ReusedExchange is RangePartitioning + val df3 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + val df4 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + checkReusedExchangeOutputPartitioningRewrite(df3.union(df4), classOf[RangePartitioning]) + + // InMemoryTableScan is HashPartitioning + Seq(1 -> "a").toDF("i", "j").repartition($"i").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").repartition($"i"), classOf[HashPartitioning]) + + // InMemoryTableScan is RangePartitioning + spark.range(1, 100, 1, 10).toDF().persist() + checkInMemoryTableScanOutputPartitioningRewrite( + spark.range(1, 100, 1, 10).toDF(), classOf[RangePartitioning]) + + // InMemoryTableScan is PartitioningCollection + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m"), + classOf[PartitioningCollection]) + } + } } // Used for unit-testing EnsureRequirements diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 3e31d22e15c0e..5c15ecd42fa0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,6 +23,7 @@ import scala.collection.mutable import scala.util.{Random, Try} import scala.util.control.NonFatal +import org.mockito.Mockito._ import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} @@ -54,6 +55,8 @@ class UnsafeFixedWidthAggregationMapSuite private var memoryManager: TestMemoryManager = null private var taskMemoryManager: TaskMemoryManager = null + private var taskContext: TaskContext = null + def testWithMemoryLeakDetection(name: String)(f: => Unit) { def cleanup(): Unit = { if (taskMemoryManager != null) { @@ -67,6 +70,8 @@ class UnsafeFixedWidthAggregationMapSuite val conf = new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false") memoryManager = new TestMemoryManager(conf) taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + taskContext = mock(classOf[TaskContext]) + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, @@ -111,7 +116,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity, PAGE_SIZE_BYTES ) @@ -124,7 +129,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity PAGE_SIZE_BYTES ) @@ -151,7 +156,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -176,7 +181,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -223,7 +228,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -263,7 +268,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, StructType(Nil), StructType(Nil), - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -307,7 +312,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity pageSize ) @@ -344,7 +349,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity pageSize ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index bf588d3bb7841..c882a9dd2148c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -231,7 +231,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { // Make sure we can successfully create a UnsafeKVExternalSorter with a `BytesToBytesMap` // which has duplicated keys and the number of entries exceeds its capacity. try { - TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, null, null)) + val context = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties(), null) + TaskContext.setTaskContext(context) new UnsafeKVExternalSorter( schema, schema, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index a3ae93810aa3c..d305ce3e698ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -21,15 +21,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File} import java.util.Properties import org.apache.spark._ -import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row +import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ import org.apache.spark.storage.ShuffleBlockId -import org.apache.spark.util.Utils import org.apache.spark.util.collection.ExternalSorter /** @@ -43,7 +41,7 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea } } -class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { +class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession { private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { val converter = unsafeRowConverter(schema) @@ -58,7 +56,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } test("toUnsafeRow() test helper method") { - // This currently doesnt work because the generic getter throws an exception. + // This currently doesn't work because the generic getter throws an exception. val row = Row("Hello", 123) val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) assert(row.getString(0) === unsafeRow.getUTF8String(0).toString) @@ -97,59 +95,43 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } test("SPARK-10466: external sorter spilling with unsafe row serializer") { - var sc: SparkContext = null - var outputFile: File = null - val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten - Utils.tryWithSafeFinally { - val conf = new SparkConf() - .set("spark.shuffle.spill.initialMemoryThreshold", "1") - .set("spark.shuffle.sort.bypassMergeThreshold", "0") - .set("spark.testing.memory", "80000") - - sc = new SparkContext("local", "test", conf) - outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") - // prepare data - val converter = unsafeRowConverter(Array(IntegerType)) - val data = (1 to 10000).iterator.map { i => - (i, converter(Row(i))) - } - val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) - val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) - - val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( - taskContext, - partitioner = Some(new HashPartitioner(10)), - serializer = new UnsafeRowSerializer(numFields = 1)) - - // Ensure we spilled something and have to merge them later - assert(sorter.numSpills === 0) - sorter.insertAll(data) - assert(sorter.numSpills > 0) + val conf = new SparkConf() + .set("spark.shuffle.spill.initialMemoryThreshold", "1") + .set("spark.shuffle.sort.bypassMergeThreshold", "0") + .set("spark.testing.memory", "80000") + spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate() + val outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") + outputFile.deleteOnExit() + // prepare data + val converter = unsafeRowConverter(Array(IntegerType)) + val data = (1 to 10000).iterator.map { i => + (i, converter(Row(i))) + } + val taskMemoryManager = new TaskMemoryManager(spark.sparkContext.env.memoryManager, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) - // Merging spilled files should not throw assertion error - sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile) - } { - // Clean up - if (sc != null) { - sc.stop() - } + val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( + taskContext, + partitioner = Some(new HashPartitioner(10)), + serializer = new UnsafeRowSerializer(numFields = 1)) - // restore the spark env - SparkEnv.set(oldEnv) + // Ensure we spilled something and have to merge them later + assert(sorter.numSpills === 0) + sorter.insertAll(data) + assert(sorter.numSpills > 0) - if (outputFile != null) { - outputFile.delete() - } - } + // Merging spilled files should not throw assertion error + sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile) } test("SPARK-10403: unsafe row serializer with SortShuffleManager") { val conf = new SparkConf().set("spark.shuffle.manager", "sort") - sc = new SparkContext("local", "test", conf) + spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate() val row = Row("Hello", 123) val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) - val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow))) - .asInstanceOf[RDD[Product2[Int, InternalRow]]] + val rowsRDD = spark.sparkContext.parallelize( + Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow)) + ).asInstanceOf[RDD[Product2[Int, InternalRow]]] val dependency = new ShuffleDependency[Int, InternalRow, InternalRow]( rowsRDD, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 9180a22c260f1..b714dcd5269fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -51,12 +51,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } test("Aggregate with grouping keys should be included in WholeStageCodegen") { - val df = spark.range(3).groupBy("id").count().orderBy("id") + val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2) val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) - assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) + assert(df.collect() === Array(Row(0, 1), Row(2, 1), Row(4, 1))) } test("BroadcastHashJoin should be included in WholeStageCodegen") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala new file mode 100644 index 0000000000000..8711f5a8fa1ce --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -0,0 +1,836 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.benchmark + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.parquet.{SpecificParquetRecordReaderBase, VectorizedParquetRecordReader} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnVector +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure data source read performance. + * To run this: + * spark-submit --class + */ +object DataSourceReadBenchmark { + val conf = new SparkConf() + .setAppName("DataSourceReadBenchmark") + // Since `spark.master` always exists, overrides this value + .set("spark.master", "local[1]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .setIfMissing("spark.ui.enabled", "false") + + val spark = SparkSession.builder.config(conf).getOrCreate() + + // Set default configs. Individual cases will change them if necessary. + spark.conf.set(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key, "true") + spark.conf.set(SQLConf.ORC_COPY_BATCH_TO_SPARK.key, "false") + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = { + val testDf = if (partition.isDefined) { + df.write.partitionBy(partition.get) + } else { + df.write + } + + saveAsCsvTable(testDf, dir.getCanonicalPath + "/csv") + saveAsJsonTable(testDf, dir.getCanonicalPath + "/json") + saveAsParquetTable(testDf, dir.getCanonicalPath + "/parquet") + saveAsOrcTable(testDf, dir.getCanonicalPath + "/orc") + } + + private def saveAsCsvTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "gzip").option("header", true).csv(dir) + spark.read.option("header", true).csv(dir).createOrReplaceTempView("csvTable") + } + + private def saveAsJsonTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "gzip").json(dir) + spark.read.json(dir).createOrReplaceTempView("jsonTable") + } + + private def saveAsParquetTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "snappy").parquet(dir) + spark.read.parquet(dir).createOrReplaceTempView("parquetTable") + } + + private def saveAsOrcTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "snappy").orc(dir) + spark.read.orc(dir).createOrReplaceTempView("orcTable") + } + + def numericScanBenchmark(values: Int, dataType: DataType): Unit = { + // Benchmarks running through spark sql. + val sqlBenchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + + // Benchmarks driving reader component directly. + val parquetReaderBenchmark = new Benchmark( + s"Parquet Reader Single ${dataType.sql} Column Scan", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) + + sqlBenchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(id) from csvTable").collect() + } + + sqlBenchmark.addCase("SQL Json") { _ => + spark.sql("select sum(id) from jsonTable").collect() + } + + sqlBenchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(id) from parquetTable").collect() + } + + sqlBenchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(id) from parquetTable").collect() + } + } + + sqlBenchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + + sqlBenchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + sqlBenchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 22964 / 23096 0.7 1460.0 1.0X + SQL Json 8469 / 8593 1.9 538.4 2.7X + SQL Parquet Vectorized 164 / 177 95.8 10.4 139.9X + SQL Parquet MR 1687 / 1706 9.3 107.2 13.6X + SQL ORC Vectorized 191 / 197 82.3 12.2 120.2X + SQL ORC Vectorized with copy 215 / 219 73.2 13.7 106.9X + SQL ORC MR 1392 / 1412 11.3 88.5 16.5X + + + SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 24090 / 24097 0.7 1531.6 1.0X + SQL Json 8791 / 8813 1.8 558.9 2.7X + SQL Parquet Vectorized 204 / 212 77.0 13.0 117.9X + SQL Parquet MR 1813 / 1850 8.7 115.3 13.3X + SQL ORC Vectorized 226 / 230 69.7 14.4 106.7X + SQL ORC Vectorized with copy 295 / 298 53.3 18.8 81.6X + SQL ORC MR 1526 / 1549 10.3 97.1 15.8X + + + SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 25637 / 25791 0.6 1629.9 1.0X + SQL Json 9532 / 9570 1.7 606.0 2.7X + SQL Parquet Vectorized 181 / 191 86.8 11.5 141.5X + SQL Parquet MR 2210 / 2227 7.1 140.5 11.6X + SQL ORC Vectorized 309 / 317 50.9 19.6 83.0X + SQL ORC Vectorized with copy 316 / 322 49.8 20.1 81.2X + SQL ORC MR 1650 / 1680 9.5 104.9 15.5X + + + SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 31617 / 31764 0.5 2010.1 1.0X + SQL Json 12440 / 12451 1.3 790.9 2.5X + SQL Parquet Vectorized 284 / 315 55.4 18.0 111.4X + SQL Parquet MR 2382 / 2390 6.6 151.5 13.3X + SQL ORC Vectorized 398 / 403 39.5 25.3 79.5X + SQL ORC Vectorized with copy 410 / 413 38.3 26.1 77.1X + SQL ORC MR 1783 / 1813 8.8 113.4 17.7X + + + SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 26679 / 26742 0.6 1696.2 1.0X + SQL Json 12490 / 12541 1.3 794.1 2.1X + SQL Parquet Vectorized 174 / 183 90.4 11.1 153.3X + SQL Parquet MR 2201 / 2223 7.1 140.0 12.1X + SQL ORC Vectorized 415 / 429 37.9 26.4 64.3X + SQL ORC Vectorized with copy 422 / 428 37.2 26.9 63.2X + SQL ORC MR 1767 / 1773 8.9 112.3 15.1X + + + SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 34223 / 34324 0.5 2175.8 1.0X + SQL Json 17784 / 17785 0.9 1130.7 1.9X + SQL Parquet Vectorized 277 / 283 56.7 17.6 123.4X + SQL Parquet MR 2356 / 2386 6.7 149.8 14.5X + SQL ORC Vectorized 533 / 536 29.5 33.9 64.2X + SQL ORC Vectorized with copy 541 / 546 29.1 34.4 63.3X + SQL ORC MR 2166 / 2177 7.3 137.7 15.8X + */ + sqlBenchmark.run() + + // Driving the parquet reader in batch mode directly. + val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray + val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize + parquetReaderBenchmark.addCase("ParquetReader Vectorized") { _ => + var longSum = 0L + var doubleSum = 0.0 + val aggregateValue: (ColumnVector, Int) => Unit = dataType match { + case ByteType => (col: ColumnVector, i: Int) => longSum += col.getByte(i) + case ShortType => (col: ColumnVector, i: Int) => longSum += col.getShort(i) + case IntegerType => (col: ColumnVector, i: Int) => longSum += col.getInt(i) + case LongType => (col: ColumnVector, i: Int) => longSum += col.getLong(i) + case FloatType => (col: ColumnVector, i: Int) => doubleSum += col.getFloat(i) + case DoubleType => (col: ColumnVector, i: Int) => doubleSum += col.getDouble(i) + } + + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + val col = batch.column(0) + while (reader.nextBatch()) { + val numRows = batch.numRows() + var i = 0 + while (i < numRows) { + if (!col.isNullAt(i)) aggregateValue(col, i) + i += 1 + } + } + } finally { + reader.close() + } + } + } + + // Decoding in vectorized but having the reader return rows. + parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => + var longSum = 0L + var doubleSum = 0.0 + val aggregateValue: (InternalRow) => Unit = dataType match { + case ByteType => (col: InternalRow) => longSum += col.getByte(0) + case ShortType => (col: InternalRow) => longSum += col.getShort(0) + case IntegerType => (col: InternalRow) => longSum += col.getInt(0) + case LongType => (col: InternalRow) => longSum += col.getLong(0) + case FloatType => (col: InternalRow) => doubleSum += col.getFloat(0) + case DoubleType => (col: InternalRow) => doubleSum += col.getDouble(0) + } + + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + while (reader.nextBatch()) { + val it = batch.rowIterator() + while (it.hasNext) { + val record = it.next() + if (!record.isNullAt(0)) aggregateValue(record) + } + } + } finally { + reader.close() + } + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 198 / 202 79.4 12.6 1.0X + ParquetReader Vectorized -> Row 119 / 121 132.3 7.6 1.7X + + + Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 282 / 287 55.8 17.9 1.0X + ParquetReader Vectorized -> Row 246 / 247 64.0 15.6 1.1X + + + Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 258 / 262 60.9 16.4 1.0X + ParquetReader Vectorized -> Row 259 / 260 60.8 16.5 1.0X + + + Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 361 / 369 43.6 23.0 1.0X + ParquetReader Vectorized -> Row 361 / 371 43.6 22.9 1.0X + + + Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 253 / 261 62.2 16.1 1.0X + ParquetReader Vectorized -> Row 254 / 256 61.9 16.2 1.0X + + + Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 357 / 364 44.0 22.7 1.0X + ParquetReader Vectorized -> Row 358 / 366 44.0 22.7 1.0X + */ + parquetReaderBenchmark.run() + } + } + } + + def intStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Int and String Scan", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("SELECT CAST(value AS INT) AS c1, CAST(value as STRING) AS c2 FROM t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(c1), sum(length(c2)) from csvTable").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql("select sum(c1), sum(length(c2)) from jsonTable").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(c1), sum(length(c2)) from parquetTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(c1), sum(length(c2)) from parquetTable").collect() + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM orcTable").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM orcTable").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM orcTable").collect() + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 27145 / 27158 0.4 2588.7 1.0X + SQL Json 12969 / 13337 0.8 1236.8 2.1X + SQL Parquet Vectorized 2419 / 2448 4.3 230.7 11.2X + SQL Parquet MR 4631 / 4633 2.3 441.7 5.9X + SQL ORC Vectorized 2412 / 2465 4.3 230.0 11.3X + SQL ORC Vectorized with copy 2633 / 2675 4.0 251.1 10.3X + SQL ORC MR 4280 / 4350 2.4 408.2 6.3X + */ + benchmark.run() + } + } + } + + def repeatedStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Repeated String", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("select cast((value % 200) + 10000 as STRING) as c1 from t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(length(c1)) from csvTable").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql("select sum(length(c1)) from jsonTable").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(length(c1)) from parquetTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(length(c1)) from parquetTable").collect() + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("select sum(length(c1)) from orcTable").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("select sum(length(c1)) from orcTable").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(length(c1)) from orcTable").collect() + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 17345 / 17424 0.6 1654.1 1.0X + SQL Json 8639 / 8664 1.2 823.9 2.0X + SQL Parquet Vectorized 839 / 854 12.5 80.0 20.7X + SQL Parquet MR 1771 / 1775 5.9 168.9 9.8X + SQL ORC Vectorized 550 / 569 19.1 52.4 31.6X + SQL ORC Vectorized with copy 785 / 849 13.4 74.9 22.1X + SQL ORC MR 2168 / 2202 4.8 206.7 8.0X + */ + benchmark.run() + } + } + } + + def partitionTableScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Partitioned Table", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) + + benchmark.addCase("Data column - CSV") { _ => + spark.sql("select sum(id) from csvTable").collect() + } + + benchmark.addCase("Data column - Json") { _ => + spark.sql("select sum(id) from jsonTable").collect() + } + + benchmark.addCase("Data column - Parquet Vectorized") { _ => + spark.sql("select sum(id) from parquetTable").collect() + } + + benchmark.addCase("Data column - Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(id) from parquetTable").collect() + } + } + + benchmark.addCase("Data column - ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + + benchmark.addCase("Data column - ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + benchmark.addCase("Data column - ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + benchmark.addCase("Partition column - CSV") { _ => + spark.sql("select sum(p) from csvTable").collect() + } + + benchmark.addCase("Partition column - Json") { _ => + spark.sql("select sum(p) from jsonTable").collect() + } + + benchmark.addCase("Partition column - Parquet Vectorized") { _ => + spark.sql("select sum(p) from parquetTable").collect() + } + + benchmark.addCase("Partition column - Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(p) from parquetTable").collect() + } + } + + benchmark.addCase("Partition column - ORC Vectorized") { _ => + spark.sql("SELECT sum(p) FROM orcTable").collect() + } + + benchmark.addCase("Partition column - ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p) FROM orcTable").collect() + } + } + + benchmark.addCase("Partition column - ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p) FROM orcTable").collect() + } + } + + benchmark.addCase("Both columns - CSV") { _ => + spark.sql("select sum(p), sum(id) from csvTable").collect() + } + + benchmark.addCase("Both columns - Json") { _ => + spark.sql("select sum(p), sum(id) from jsonTable").collect() + } + + benchmark.addCase("Both columns - Parquet Vectorized") { _ => + spark.sql("select sum(p), sum(id) from parquetTable").collect() + } + + benchmark.addCase("Both columns - Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(p), sum(id) from parquetTable").collect + } + } + + benchmark.addCase("Both columns - ORC Vectorized") { _ => + spark.sql("SELECT sum(p), sum(id) FROM orcTable").collect() + } + + benchmark.addCase("Both column - ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p), sum(id) FROM orcTable").collect() + } + } + + benchmark.addCase("Both columns - ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p), sum(id) FROM orcTable").collect() + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + Data column - CSV 32613 / 32841 0.5 2073.4 1.0X + Data column - Json 13343 / 13469 1.2 848.3 2.4X + Data column - Parquet Vectorized 302 / 318 52.1 19.2 108.0X + Data column - Parquet MR 2908 / 2924 5.4 184.9 11.2X + Data column - ORC Vectorized 412 / 425 38.1 26.2 79.1X + Data column - ORC Vectorized with copy 442 / 446 35.6 28.1 73.8X + Data column - ORC MR 2390 / 2396 6.6 152.0 13.6X + Partition column - CSV 9626 / 9683 1.6 612.0 3.4X + Partition column - Json 10909 / 10923 1.4 693.6 3.0X + Partition column - Parquet Vectorized 69 / 76 228.4 4.4 473.6X + Partition column - Parquet MR 1898 / 1933 8.3 120.7 17.2X + Partition column - ORC Vectorized 67 / 74 236.0 4.2 489.4X + Partition column - ORC Vectorized with copy 65 / 72 241.9 4.1 501.6X + Partition column - ORC MR 1743 / 1749 9.0 110.8 18.7X + Both columns - CSV 35523 / 35552 0.4 2258.5 0.9X + Both columns - Json 13676 / 13681 1.2 869.5 2.4X + Both columns - Parquet Vectorized 317 / 326 49.5 20.2 102.7X + Both columns - Parquet MR 3333 / 3336 4.7 211.9 9.8X + Both columns - ORC Vectorized 441 / 446 35.6 28.1 73.9X + Both column - ORC Vectorized with copy 517 / 524 30.4 32.9 63.1X + Both columns - ORC MR 2574 / 2577 6.1 163.6 12.7X + */ + benchmark.run() + } + } + } + + def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { + val benchmark = new Benchmark("String with Nulls Scan", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql( + s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " + + s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(length(c2)) from csvTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql("select sum(length(c2)) from jsonTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(length(c2)) from parquetTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(length(c2)) from parquetTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + } + + val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray + val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize + benchmark.addCase("ParquetReader Vectorized") { num => + var sum = 0 + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) + val batch = reader.resultBatch() + while (reader.nextBatch()) { + val rowIterator = batch.rowIterator() + while (rowIterator.hasNext) { + val row = rowIterator.next() + val value = row.getUTF8String(0) + if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes() + } + } + } finally { + reader.close() + } + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM orcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM orcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM orcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 14875 / 14920 0.7 1418.6 1.0X + SQL Json 10974 / 10992 1.0 1046.5 1.4X + SQL Parquet Vectorized 1711 / 1750 6.1 163.2 8.7X + SQL Parquet MR 3838 / 3884 2.7 366.0 3.9X + ParquetReader Vectorized 1155 / 1168 9.1 110.2 12.9X + SQL ORC Vectorized 1341 / 1380 7.8 127.9 11.1X + SQL ORC Vectorized with copy 1659 / 1716 6.3 158.2 9.0X + SQL ORC MR 3594 / 3634 2.9 342.7 4.1X + + + String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 17219 / 17264 0.6 1642.1 1.0X + SQL Json 8843 / 8864 1.2 843.3 1.9X + SQL Parquet Vectorized 1169 / 1178 9.0 111.4 14.7X + SQL Parquet MR 2676 / 2697 3.9 255.2 6.4X + ParquetReader Vectorized 1068 / 1071 9.8 101.8 16.1X + SQL ORC Vectorized 1319 / 1319 7.9 125.8 13.1X + SQL ORC Vectorized with copy 1638 / 1639 6.4 156.2 10.5X + SQL ORC MR 3230 / 3257 3.2 308.1 5.3X + + + String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 13976 / 14053 0.8 1332.8 1.0X + SQL Json 5166 / 5176 2.0 492.6 2.7X + SQL Parquet Vectorized 274 / 282 38.2 26.2 50.9X + SQL Parquet MR 1553 / 1555 6.8 148.1 9.0X + ParquetReader Vectorized 241 / 246 43.5 23.0 57.9X + SQL ORC Vectorized 476 / 479 22.0 45.4 29.3X + SQL ORC Vectorized with copy 584 / 588 17.9 55.7 23.9X + SQL ORC MR 1720 / 1734 6.1 164.1 8.1X + */ + benchmark.run() + } + } + } + + def columnsBenchmark(values: Int, width: Int): Unit = { + val benchmark = new Benchmark(s"Single Column Scan from $width columns", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + val middle = width / 2 + val selectExpr = (1 to width).map(i => s"value as c$i") + spark.range(values).map(_ => Random.nextLong).toDF() + .selectExpr(selectExpr: _*).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT * FROM t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql(s"SELECT sum(c$middle) FROM csvTable").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql(s"SELECT sum(c$middle) FROM jsonTable").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql(s"SELECT sum(c$middle) FROM parquetTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM parquetTable").collect() + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql(s"SELECT sum(c$middle) FROM orcTable").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql(s"SELECT sum(c$middle) FROM orcTable").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM orcTable").collect() + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + Single Column Scan from 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 3478 / 3481 0.3 3316.4 1.0X + SQL Json 2646 / 2654 0.4 2523.6 1.3X + SQL Parquet Vectorized 67 / 72 15.8 63.5 52.2X + SQL Parquet MR 207 / 214 5.1 197.6 16.8X + SQL ORC Vectorized 69 / 76 15.2 66.0 50.3X + SQL ORC Vectorized with copy 70 / 76 15.0 66.5 49.9X + SQL ORC MR 299 / 303 3.5 285.1 11.6X + + + Single Column Scan from 50 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 9214 / 9236 0.1 8786.7 1.0X + SQL Json 9943 / 9978 0.1 9482.7 0.9X + SQL Parquet Vectorized 77 / 86 13.6 73.3 119.8X + SQL Parquet MR 229 / 235 4.6 218.6 40.2X + SQL ORC Vectorized 84 / 96 12.5 80.0 109.9X + SQL ORC Vectorized with copy 83 / 91 12.6 79.4 110.7X + SQL ORC MR 843 / 854 1.2 804.0 10.9X + + + Single Column Scan from 100 columns Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 16503 / 16622 0.1 15738.9 1.0X + SQL Json 19109 / 19184 0.1 18224.2 0.9X + SQL Parquet Vectorized 99 / 108 10.6 94.3 166.8X + SQL Parquet MR 253 / 264 4.1 241.6 65.1X + SQL ORC Vectorized 107 / 114 9.8 101.6 154.8X + SQL ORC Vectorized with copy 107 / 118 9.8 102.1 154.1X + SQL ORC MR 1526 / 1529 0.7 1455.3 10.8X + */ + benchmark.run() + } + } + } + + def main(args: Array[String]): Unit = { + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => + numericScanBenchmark(1024 * 1024 * 15, dataType) + } + intStringScanBenchmark(1024 * 1024 * 10) + repeatedStringScanBenchmark(1024 * 1024 * 10) + partitionTableScanBenchmark(1024 * 1024 * 15) + for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { + stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + } + for (columnWidth <- List(10, 50, 100)) { + columnsBenchmark(1024 * 1024 * 1, columnWidth) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala new file mode 100644 index 0000000000000..2d2cdebd067c1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Benchmark + +/** + * Benchmark to measure data source write performance. + * By default it measures 4 data source format: Parquet, ORC, JSON, CSV: + * spark-submit --class + * To measure specified formats, run it with arguments: + * spark-submit --class format1 [format2] [...] + */ +object DataSourceWriteBenchmark { + val conf = new SparkConf() + .setAppName("DataSourceWriteBenchmark") + .setIfMissing("spark.master", "local[1]") + .set("spark.sql.parquet.compression.codec", "snappy") + .set("spark.sql.orc.compression.codec", "snappy") + + val spark = SparkSession.builder.config(conf).getOrCreate() + + // Set default configs. Individual cases will change them if necessary. + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + + val tempTable = "temp" + val numRows = 1024 * 1024 * 15 + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + spark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + + def writeNumeric(table: String, format: String, benchmark: Benchmark, dataType: String): Unit = { + spark.sql(s"create table $table(id $dataType) using $format") + benchmark.addCase(s"Output Single $dataType Column") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS $dataType) AS c1 FROM $tempTable") + } + } + + def writeIntString(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(c1 INT, c2 STRING) USING $format") + benchmark.addCase("Output Int and String Column") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS " + + s"c1, CAST(id AS STRING) AS c2 FROM $tempTable") + } + } + + def writePartition(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(p INT, id INT) USING $format PARTITIONED BY (p)") + benchmark.addCase("Output Partitions") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS id," + + s" CAST(id % 2 AS INT) AS p FROM $tempTable") + } + } + + def writeBucket(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(c1 INT, c2 INT) USING $format CLUSTERED BY (c2) INTO 2 BUCKETS") + benchmark.addCase("Output Buckets") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS " + + s"c1, CAST(id AS INT) AS c2 FROM $tempTable") + } + } + + def main(args: Array[String]): Unit = { + val tableInt = "tableInt" + val tableDouble = "tableDouble" + val tableIntString = "tableIntString" + val tablePartition = "tablePartition" + val tableBucket = "tableBucket" + val formats: Seq[String] = if (args.isEmpty) { + Seq("Parquet", "ORC", "JSON", "CSV") + } else { + args + } + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + Parquet writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1815 / 1932 8.7 115.4 1.0X + Output Single Double Column 1877 / 1878 8.4 119.3 1.0X + Output Int and String Column 6265 / 6543 2.5 398.3 0.3X + Output Partitions 4067 / 4457 3.9 258.6 0.4X + Output Buckets 5608 / 5820 2.8 356.6 0.3X + + ORC writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1201 / 1239 13.1 76.3 1.0X + Output Single Double Column 1542 / 1600 10.2 98.0 0.8X + Output Int and String Column 6495 / 6580 2.4 412.9 0.2X + Output Partitions 3648 / 3842 4.3 231.9 0.3X + Output Buckets 5022 / 5145 3.1 319.3 0.2X + + JSON writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1988 / 2093 7.9 126.4 1.0X + Output Single Double Column 2854 / 2911 5.5 181.4 0.7X + Output Int and String Column 6467 / 6653 2.4 411.1 0.3X + Output Partitions 4548 / 5055 3.5 289.1 0.4X + Output Buckets 5664 / 5765 2.8 360.1 0.4X + + CSV writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 3025 / 3190 5.2 192.3 1.0X + Output Single Double Column 3575 / 3634 4.4 227.3 0.8X + Output Int and String Column 7313 / 7399 2.2 464.9 0.4X + Output Partitions 5105 / 5190 3.1 324.6 0.6X + Output Buckets 6986 / 6992 2.3 444.1 0.4X + */ + withTempTable(tempTable) { + spark.range(numRows).createOrReplaceTempView(tempTable) + formats.foreach { format => + withTable(tableInt, tableDouble, tableIntString, tablePartition, tableBucket) { + val benchmark = new Benchmark(s"$format writer benchmark", numRows) + writeNumeric(tableInt, format, benchmark, "Int") + writeNumeric(tableDouble, format, benchmark, "Double") + writeIntString(tableIntString, format, benchmark) + writePartition(tablePartition, format, benchmark) + writeBucket(tableBucket, format, benchmark) + benchmark.run() + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala new file mode 100644 index 0000000000000..567a8ebf9d102 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -0,0 +1,408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import java.io.{File, FileOutputStream, OutputStream} + +import scala.util.{Random, Try} + +import org.scalatest.{BeforeAndAfterEachTestData, Suite, TestData} + +import org.apache.spark.SparkConf +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.functions.monotonically_increasing_id +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType +import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, TimestampType} +import org.apache.spark.util.{Benchmark, Utils} + +/** + * Benchmark to measure read performance with Filter pushdown. + * To run this: + * build/sbt "sql/test-only *FilterPushdownBenchmark" + * + * Results will be written to "benchmarks/FilterPushdownBenchmark-results.txt". + */ +class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfterEachTest { + private val conf = new SparkConf() + .setAppName(this.getClass.getSimpleName) + // Since `spark.master` always exists, overrides this value + .set("spark.master", "local[1]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .setIfMissing("spark.ui.enabled", "false") + .setIfMissing("orc.compression", "snappy") + .setIfMissing("spark.sql.parquet.compression.codec", "snappy") + + private val numRows = 1024 * 1024 * 15 + private val width = 5 + private val mid = numRows / 2 + private val blockSize = 1048576 + + private val spark = SparkSession.builder().config(conf).getOrCreate() + + private var out: OutputStream = _ + + override def beforeAll() { + super.beforeAll() + out = new FileOutputStream(new File("benchmarks/FilterPushdownBenchmark-results.txt")) + } + + override def beforeEach(td: TestData) { + super.beforeEach(td) + val separator = "=" * 96 + val testHeader = (separator + '\n' + td.name + '\n' + separator + '\n' + '\n').getBytes + out.write(testHeader) + } + + override def afterEach(td: TestData) { + out.write('\n') + super.afterEach(td) + } + + override def afterAll() { + try { + out.close() + } finally { + super.afterAll() + } + } + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private def prepareTable( + dir: File, numRows: Int, width: Int, useStringForValue: Boolean): Unit = { + import spark.implicits._ + val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") + val valueCol = if (useStringForValue) { + monotonically_increasing_id().cast("string") + } else { + monotonically_increasing_id() + } + val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) + .withColumn("value", valueCol) + .sort("value") + + saveAsTable(df, dir) + } + + private def prepareStringDictTable( + dir: File, numRows: Int, numDistinctValues: Int, width: Int): Unit = { + val selectExpr = (0 to width).map { + case 0 => s"CAST(id % $numDistinctValues AS STRING) AS value" + case i => s"CAST(rand() AS STRING) c$i" + } + val df = spark.range(numRows).selectExpr(selectExpr: _*).sort("value") + + saveAsTable(df, dir) + } + + private def saveAsTable(df: DataFrame, dir: File): Unit = { + val orcPath = dir.getCanonicalPath + "/orc" + val parquetPath = dir.getCanonicalPath + "/parquet" + + // To always turn on dictionary encoding, we set 1.0 at the threshold (the default is 0.8) + df.write.mode("overwrite") + .option("orc.dictionary.key.threshold", 1.0) + .option("orc.stripe.size", blockSize).orc(orcPath) + spark.read.orc(orcPath).createOrReplaceTempView("orcTable") + + df.write.mode("overwrite") + .option("parquet.block.size", blockSize).parquet(parquetPath) + spark.read.parquet(parquetPath).createOrReplaceTempView("parquetTable") + } + + def filterPushDownBenchmark( + values: Int, + title: String, + whereExpr: String, + selectExpr: String = "*"): Unit = { + val benchmark = new Benchmark(title, values, minNumIters = 5, output = Some(out)) + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() + } + } + } + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() + } + } + } + + benchmark.run() + } + + private def runIntBenchmark(numRows: Int, width: Int, mid: Int): Unit = { + Seq("value IS NULL", s"$mid < value AND value < $mid").foreach { whereExpr => + val title = s"Select 0 int row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"value = $mid", + s"value <=> $mid", + s"$mid <= value AND value <= $mid", + s"${mid - 1} < value AND value < ${mid + 1}" + ).foreach { whereExpr => + val title = s"Select 1 int row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% int rows (value < ${numRows * percent / 100})", + s"value < ${numRows * percent / 100}", + selectExpr + ) + } + + Seq("value IS NOT NULL", "value > -1", "value != -1").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all int rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + + private def runStringBenchmark( + numRows: Int, width: Int, searchValue: Int, colType: String): Unit = { + Seq("value IS NULL", s"'$searchValue' < value AND value < '$searchValue'") + .foreach { whereExpr => + val title = s"Select 0 $colType row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"value = '$searchValue'", + s"value <=> '$searchValue'", + s"'$searchValue' <= value AND value <= '$searchValue'" + ).foreach { whereExpr => + val title = s"Select 1 $colType row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + + Seq("value IS NOT NULL").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all $colType rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + + ignore("Pushdown for many distinct value case") { + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + Seq(true, false).foreach { useStringForValue => + prepareTable(dir, numRows, width, useStringForValue) + if (useStringForValue) { + runStringBenchmark(numRows, width, mid, "string") + } else { + runIntBenchmark(numRows, width, mid) + } + } + } + } + } + + ignore("Pushdown for few distinct value case (use dictionary encoding)") { + withTempPath { dir => + val numDistinctValues = 200 + + withTempTable("orcTable", "patquetTable") { + prepareStringDictTable(dir, numRows, numDistinctValues, width) + runStringBenchmark(numRows, width, numDistinctValues / 2, "distinct string") + } + } + } + + ignore("Pushdown benchmark for StringStartsWith") { + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + prepareTable(dir, numRows, width, true) + Seq( + "value like '10%'", + "value like '1000%'", + s"value like '${mid.toString.substring(0, mid.toString.length - 1)}%'" + ).foreach { whereExpr => + val title = s"StringStartsWith filter: ($whereExpr)" + filterPushDownBenchmark(numRows, title, whereExpr) + } + } + } + } + + ignore(s"Pushdown benchmark for ${DecimalType.simpleString}") { + withTempPath { dir => + Seq( + s"decimal(${Decimal.MAX_INT_DIGITS}, 2)", + s"decimal(${Decimal.MAX_LONG_DIGITS}, 2)", + s"decimal(${DecimalType.MAX_PRECISION}, 2)" + ).foreach { dt => + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", monotonically_increasing_id().cast(dt)) + withTempTable("orcTable", "patquetTable") { + saveAsTable(df, dir) + + Seq(s"value = $mid").foreach { whereExpr => + val title = s"Select 1 $dt row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% $dt rows (value < ${numRows * percent / 100})", + s"value < ${numRows * percent / 100}", + selectExpr + ) + } + } + } + } + } + + ignore("Pushdown benchmark for InSet -> InFilters") { + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + prepareTable(dir, numRows, width, false) + Seq(5, 10, 50, 100).foreach { count => + Seq(10, 50, 90).foreach { distribution => + val filter = + Range(0, count).map(r => scala.util.Random.nextInt(numRows * distribution / 100)) + val whereExpr = s"value in(${filter.mkString(",")})" + val title = s"InSet -> InFilters (values count: $count, distribution: $distribution)" + filterPushDownBenchmark(numRows, title, whereExpr) + } + } + } + } + } + + ignore(s"Pushdown benchmark for ${ByteType.simpleString}") { + withTempPath { dir => + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", (monotonically_increasing_id() % Byte.MaxValue).cast(ByteType)) + .orderBy("value") + withTempTable("orcTable", "patquetTable") { + saveAsTable(df, dir) + + Seq(s"value = CAST(${Byte.MaxValue / 2} AS ${ByteType.simpleString})") + .foreach { whereExpr => + val title = s"Select 1 ${ByteType.simpleString} row ($whereExpr)" + .replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% ${ByteType.simpleString} rows " + + s"(value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString}))", + s"value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString})", + selectExpr + ) + } + } + } + } + + ignore(s"Pushdown benchmark for Timestamp") { + withTempPath { dir => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> true.toString) { + ParquetOutputTimestampType.values.toSeq.map(_.toString).foreach { fileType => + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> fileType) { + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", monotonically_increasing_id().cast(TimestampType)) + withTempTable("orcTable", "patquetTable") { + saveAsTable(df, dir) + + Seq(s"value = CAST($mid AS timestamp)").foreach { whereExpr => + val title = s"Select 1 timestamp stored as $fileType row ($whereExpr)" + .replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% timestamp stored as $fileType rows " + + s"(value < CAST(${numRows * percent / 100} AS timestamp))", + s"value < CAST(${numRows * percent / 100} as timestamp)", + selectExpr + ) + } + } + } + } + } + } + } +} + +trait BenchmarkBeforeAndAfterEachTest extends BeforeAndAfterEachTestData { this: Suite => + + override def beforeEach(td: TestData) { + super.beforeEach(td) + } + + override def afterEach(td: TestData) { + super.afterEach(td) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index 69247d7f4e9aa..abe61a2c2b9c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -58,10 +58,13 @@ object TPCDSQueryBenchmark extends Logging { }.toMap } - def tpcdsAll(dataLocation: String, queries: Seq[String]): Unit = { - val tableSizes = setupTables(dataLocation) + def runTpcdsQueries( + queryLocation: String, + queries: Seq[String], + tableSizes: Map[String, Long], + nameSuffix: String = ""): Unit = { queries.foreach { name => - val queryString = resourceToString(s"tpcds/$name.sql", + val queryString = resourceToString(s"$queryLocation/$name.sql", classLoader = Thread.currentThread().getContextClassLoader) // This is an indirect hack to estimate the size of each query's input by traversing the @@ -78,7 +81,7 @@ object TPCDSQueryBenchmark extends Logging { } val numRows = queryRelations.map(tableSizes.getOrElse(_, 0L)).sum val benchmark = new Benchmark(s"TPCDS Snappy", numRows, 5) - benchmark.addCase(name) { i => + benchmark.addCase(s"$name$nameSuffix") { _ => spark.sql(queryString).collect() } logInfo(s"\n\n===== TPCDS QUERY BENCHMARK OUTPUT FOR $name =====\n") @@ -87,10 +90,20 @@ object TPCDSQueryBenchmark extends Logging { } } + def filterQueries( + origQueries: Seq[String], + args: TPCDSQueryBenchmarkArguments): Seq[String] = { + if (args.queryFilter.nonEmpty) { + origQueries.filter(args.queryFilter.contains) + } else { + origQueries + } + } + def main(args: Array[String]): Unit = { val benchmarkArgs = new TPCDSQueryBenchmarkArguments(args) - // List of all TPC-DS queries + // List of all TPC-DS v1.4 queries val tpcdsQueries = Seq( "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", @@ -103,20 +116,25 @@ object TPCDSQueryBenchmark extends Logging { "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") + // This list only includes TPC-DS v2.7 queries that are different from v1.4 ones + val tpcdsQueriesV2_7 = Seq( + "q5a", "q6", "q10a", "q11", "q12", "q14", "q14a", "q18a", + "q20", "q22", "q22a", "q24", "q27a", "q34", "q35", "q35a", "q36a", "q47", "q49", + "q51a", "q57", "q64", "q67a", "q70a", "q72", "q74", "q75", "q77a", "q78", + "q80a", "q86a", "q98") + // If `--query-filter` defined, filters the queries that this option selects - val queriesToRun = if (benchmarkArgs.queryFilter.nonEmpty) { - val queries = tpcdsQueries.filter { case queryName => - benchmarkArgs.queryFilter.contains(queryName) - } - if (queries.isEmpty) { - throw new RuntimeException( - s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}") - } - queries - } else { - tpcdsQueries + val queriesV1_4ToRun = filterQueries(tpcdsQueries, benchmarkArgs) + val queriesV2_7ToRun = filterQueries(tpcdsQueriesV2_7, benchmarkArgs) + + if ((queriesV1_4ToRun ++ queriesV2_7ToRun).isEmpty) { + throw new RuntimeException( + s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}") } - tpcdsAll(benchmarkArgs.dataLocation, queries = queriesToRun) + val tableSizes = setupTables(benchmarkArgs.dataLocation) + runTpcdsQueries(queryLocation = "tpcds", queries = queriesV1_4ToRun, tableSizes) + runTpcdsQueries(queryLocation = "tpcds-v2.7.0", queries = queriesV2_7ToRun, tableSizes, + nameSuffix = "-v2.7") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 9b7b316211d30..efc2f20a907f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -45,8 +45,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None, data.logicalPlan) - assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) - inMemoryRelation.cachedColumnBuffers.collect().head match { + assert(inMemoryRelation.cacheBuilder.cachedColumnBuffers.getStorageLevel == storageLevel) + inMemoryRelation.cacheBuilder.cachedColumnBuffers.collect().head match { case _: CachedBatch => case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}") } @@ -337,7 +337,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(cached, expectedAnswer) // Check that the right size was calculated. - assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) + assert(cached.cacheBuilder.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) } test("access primitive-type columns in CachedBatch without whole stage codegen") { @@ -503,7 +503,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { case plan: InMemoryRelation => plan }.head // InMemoryRelation's stats is file size before the underlying RDD is materialized - assert(inMemoryRelation.computeStats().sizeInBytes === 740) + assert(inMemoryRelation.computeStats().sizeInBytes === 800) // InMemoryRelation's stats is updated after materializing RDD dfFromFile.collect() @@ -516,7 +516,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Even CBO enabled, InMemoryRelation's stats keeps as the file size before table's stats // is calculated - assert(inMemoryRelation2.computeStats().sizeInBytes === 740) + assert(inMemoryRelation2.computeStats().sizeInBytes === 800) // InMemoryRelation's stats should be updated after calculating stats of the table // clear cache to simulate a fresh environment diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 3998ceca38b30..ca95aad3976e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -441,6 +441,24 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename a managed table with existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab2"))) + try { + withTable("tab1") { + sql(s"CREATE TABLE tab1 USING $dataSource AS SELECT 1, 'a'") + tableLoc.mkdir() + val ex = intercept[AnalysisException] { + sql("ALTER TABLE tab1 RENAME TO tab2") + }.getMessage + val expectedMsg = "Can not rename the managed table('`tab1`'). The associated location" + assert(ex.contains(expectedMsg)) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + private def checkSchemaInCreatedDataSourceTable( path: File, userSpecifiedSchema: Option[String], @@ -2495,7 +2513,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { test("alter datasource table add columns - text format not supported") { withTable("t1") { - sql("CREATE TABLE t1 (c1 int) USING text") + sql("CREATE TABLE t1 (c1 string) USING text") val e = intercept[AnalysisException] { sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") }.getMessage diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 8764f0c42cf9f..bceaf1a9ec061 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path, RawLocalFileSystem} import org.apache.hadoop.mapreduce.Job -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala new file mode 100644 index 0000000000000..a39a25be262a6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.Files + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext + +class HadoopFileLinesReaderSuite extends SharedSQLContext { + def getLines( + path: File, + text: String, + ranges: Seq[(Long, Long)], + delimiter: Option[String] = None, + conf: Option[Configuration] = None): Seq[String] = { + val delimOpt = delimiter.map(_.getBytes(StandardCharsets.UTF_8)) + Files.write(path.toPath, text.getBytes(StandardCharsets.UTF_8)) + + val lines = ranges.map { case (start, length) => + val file = PartitionedFile(InternalRow.empty, path.getCanonicalPath, start, length) + val hadoopConf = conf.getOrElse(spark.sparkContext.hadoopConfiguration) + val reader = new HadoopFileLinesReader(file, delimOpt, hadoopConf) + + reader.map(_.toString) + }.flatten + + lines + } + + test("A split ends at the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 1), (1, 3))) + assert(lines == Seq("a", "b")) + } + } + + test("A split cuts the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 2), (2, 2))) + assert(lines == Seq("a", "b")) + } + } + + test("A split ends at the end of the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 3), (3, 1))) + assert(lines == Seq("a", "b")) + } + } + + test("A split covers two lines") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 4), (4, 1))) + assert(lines == Seq("a", "b")) + } + } + + test("A split ends at the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 1), (1, 4)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("A split slices the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 2), (2, 3)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("The first split covers the first line and the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 4), (4, 1)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("A split cuts the first line") { + withTempPath { path => + val lines = getLines(path, text = "abc,def", ranges = Seq((0, 1)), Some(",")) + assert(lines == Seq("abc")) + } + } + + test("The split cuts both lines") { + withTempPath { path => + val lines = getLines(path, text = "abc,def", ranges = Seq((2, 2)), Some(",")) + assert(lines == Seq("def")) + } + } + + test("io.file.buffer.size is less than line length") { + val conf = spark.sparkContext.hadoopConfiguration + conf.set("io.file.buffer.size", "2") + withTempPath { path => + val lines = getLines(path, text = "abcdef\n123456", ranges = Seq((4, 4), (8, 5))) + assert(lines == Seq("123456")) + } + } + + test("line cannot be longer than line.maxlength") { + val conf = spark.sparkContext.hadoopConfiguration + conf.set("mapreduce.input.linerecordreader.line.maxlength", "5") + withTempPath { path => + val lines = getLines(path, text = "abcdef\n1234", ranges = Seq((0, 15))) + assert(lines == Seq("1234")) + } + } + + test("default delimiter is 0xd or 0xa or 0xd0xa") { + withTempPath { path => + val lines = getLines(path, text = "1\r2\n3\r\n4", ranges = Seq((0, 3), (3, 5))) + assert(lines == Seq("1", "2", "3", "4")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala new file mode 100644 index 0000000000000..23c58e175fe5e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.internal.SQLConf + +/** + * Read schema suites have the following hierarchy and aims to guarantee users + * a backward-compatible read-schema change coverage on file-based data sources, and + * to prevent future regressions. + * + * ReadSchemaSuite + * -> CSVReadSchemaSuite + * -> HeaderCSVReadSchemaSuite + * + * -> JsonReadSchemaSuite + * + * -> OrcReadSchemaSuite + * -> VectorizedOrcReadSchemaSuite + * + * -> ParquetReadSchemaSuite + * -> VectorizedParquetReadSchemaSuite + * -> MergedParquetReadSchemaSuite + */ + +/** + * All file-based data sources supports column addition and removal at the end. + */ +abstract class ReadSchemaSuite + extends AddColumnTest + with HideColumnAtTheEndTest { + + var originalConf: Boolean = _ +} + +class CSVReadSchemaSuite + extends ReadSchemaSuite + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "csv" +} + +class HeaderCSVReadSchemaSuite + extends ReadSchemaSuite + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "csv" + + override val options = Map("header" -> "true") +} + +class JsonReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "json" +} + +class OrcReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "orc" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.ORC_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, "false") + } + + override def afterAll() { + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class VectorizedOrcReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest + with BooleanTypeTest + with IntegralTypeTest + with ToDoubleTypeTest { + + override val format: String = "orc" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.ORC_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class ParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "false") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class VectorizedParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class MergedParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED) + spark.conf.set(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key, originalConf) + super.afterAll() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala new file mode 100644 index 0000000000000..2a5457e00b4ef --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} + +/** + * The reader schema is said to be evolved (or projected) when it changed after the data is + * written by writers. The followings are supported in file-based data sources. + * Note that partition columns are not maintained in files. Here, `column` means non-partition + * column. + * + * 1. Add a column + * 2. Hide a column + * 3. Change a column position + * 4. Change a column type (Upcast) + * + * Here, we consider safe changes without data loss. For example, data type changes should be + * from small types to larger types like `int`-to-`long`, not vice versa. + * + * So far, file-based data sources have the following coverages. + * + * | File Format | Coverage | Note | + * | ------------ | ------------ | ------------------------------------------------------ | + * | TEXT | N/A | Schema consists of a single string column. | + * | CSV | 1, 2, 4 | | + * | JSON | 1, 2, 3, 4 | | + * | ORC | 1, 2, 3, 4 | Native vectorized ORC reader has the widest coverage. | + * | PARQUET | 1, 2, 3 | | + * + * This aims to provide an explicit test coverage for reader schema change on file-based data + * sources. Since a file format has its own coverage, we need a test suite for each file-based + * data source with corresponding supported test case traits. + * + * The following is a hierarchy of test traits. + * + * ReadSchemaTest + * -> AddColumnTest + * -> HideColumnTest + * -> ChangePositionTest + * -> BooleanTypeTest + * -> IntegralTypeTest + * -> ToDoubleTypeTest + * -> ToDecimalTypeTest + */ + +trait ReadSchemaTest extends QueryTest with SQLTestUtils with SharedSQLContext { + val format: String + val options: Map[String, String] = Map.empty[String, String] +} + +/** + * Add column (Case 1). + * This test suite assumes that the missing column should be `null`. + */ +trait AddColumnTest extends ReadSchemaTest { + import testImplicits._ + + test("append column at the end") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq("a", "b").toDF("col1") + val df2 = df1.withColumn("col2", lit("x")) + val df3 = df2.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=one" + val dir2 = s"$path${File.separator}part=two" + val dir3 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + df3.write.format(format).options(options).save(dir3) + + val df = spark.read + .schema(df3.schema) + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("a", null, null, "one"), + Row("b", null, null, "one"), + Row("a", "x", null, "two"), + Row("b", "x", null, "two"), + Row("a", "x", "y", "three"), + Row("b", "x", "y", "three"))) + } + } +} + +/** + * Hide column (Case 2-1). + */ +trait HideColumnAtTheEndTest extends ReadSchemaTest { + import testImplicits._ + + test("hide column at the end") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b")).toDF("col1", "col2") + val df2 = df1.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=two" + val dir2 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema(df1.schema) + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("1", "a", "two"), + Row("2", "b", "two"), + Row("1", "a", "three"), + Row("2", "b", "three"))) + + val df3 = spark.read + .schema("col1 string") + .format(format) + .options(options) + .load(path) + + checkAnswer(df3, Seq( + Row("1", "two"), + Row("2", "two"), + Row("1", "three"), + Row("2", "three"))) + } + } +} + +/** + * Hide column in the middle (Case 2-2). + */ +trait HideColumnInTheMiddleTest extends ReadSchemaTest { + import testImplicits._ + + test("hide column in the middle") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b")).toDF("col1", "col2") + val df2 = df1.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=two" + val dir2 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema("col2 string") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("a", "two"), + Row("b", "two"), + Row("a", "three"), + Row("b", "three"))) + } + } +} + +/** + * Change column positions (Case 3). + * This suite assumes that all data set have the same number of columns. + */ +trait ChangePositionTest extends ReadSchemaTest { + import testImplicits._ + + test("change column position") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b"), ("3", "c")).toDF("col1", "col2") + val df2 = Seq(("d", "4"), ("e", "5"), ("f", "6")).toDF("col2", "col1") + val unionDF = df1.unionByName(df2) + + val dir1 = s"$path${File.separator}part=one" + val dir2 = s"$path${File.separator}part=two" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1", "col2") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait BooleanTypeTest extends ReadSchemaTest { + import testImplicits._ + + test("change column type from boolean to byte/short/int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val values = (1 to 10).map(_ % 2) + val booleanDF = (1 to 10).map(_ % 2 == 1).toDF("col1") + val byteDF = values.map(_.toByte).toDF("col1") + val shortDF = values.map(_.toShort).toDF("col1") + val intDF = values.toDF("col1") + val longDF = values.map(_.toLong).toDF("col1") + + booleanDF.write.mode("overwrite").format(format).options(options).save(path) + + Seq( + ("col1 byte", byteDF), + ("col1 short", shortDF), + ("col1 int", intDF), + ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToStringTypeTest extends ReadSchemaTest { + import testImplicits._ + + test("read as string") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val byteDF = (Byte.MaxValue - 2 to Byte.MaxValue).map(_.toByte).toDF("col1") + val shortDF = (Short.MaxValue - 2 to Short.MaxValue).map(_.toShort).toDF("col1") + val intDF = (Int.MaxValue - 2 to Int.MaxValue).toDF("col1") + val longDF = (Long.MaxValue - 2 to Long.MaxValue).toDF("col1") + val unionDF = byteDF.union(shortDF).union(intDF).union(longDF) + .selectExpr("cast(col1 AS STRING) col1") + + val byteDir = s"$path${File.separator}part=byte" + val shortDir = s"$path${File.separator}part=short" + val intDir = s"$path${File.separator}part=int" + val longDir = s"$path${File.separator}part=long" + + byteDF.write.format(format).options(options).save(byteDir) + shortDF.write.format(format).options(options).save(shortDir) + intDF.write.format(format).options(options).save(intDir) + longDF.write.format(format).options(options).save(longDir) + + val df = spark.read + .schema("col1 string") + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait IntegralTypeTest extends ReadSchemaTest { + + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val byteDF = values.map(_.toByte).toDF("col1") + private lazy val shortDF = values.map(_.toShort).toDF("col1") + private lazy val intDF = values.toDF("col1") + private lazy val longDF = values.map(_.toLong).toDF("col1") + + test("change column type from byte to short/int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + byteDF.write.format(format).options(options).save(path) + + Seq( + ("col1 short", shortDF), + ("col1 int", intDF), + ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("change column type from short to int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + shortDF.write.format(format).options(options).save(path) + + Seq(("col1 int", intDF), ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("change column type from int to long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + intDF.write.format(format).options(options).save(path) + + Seq(("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("read byte, int, short, long together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val byteDF = (Byte.MaxValue - 2 to Byte.MaxValue).map(_.toByte).toDF("col1") + val shortDF = (Short.MaxValue - 2 to Short.MaxValue).map(_.toShort).toDF("col1") + val intDF = (Int.MaxValue - 2 to Int.MaxValue).toDF("col1") + val longDF = (Long.MaxValue - 2 to Long.MaxValue).toDF("col1") + val unionDF = byteDF.union(shortDF).union(intDF).union(longDF) + + val byteDir = s"$path${File.separator}part=byte" + val shortDir = s"$path${File.separator}part=short" + val intDir = s"$path${File.separator}part=int" + val longDir = s"$path${File.separator}part=long" + + byteDF.write.format(format).options(options).save(byteDir) + shortDF.write.format(format).options(options).save(shortDir) + intDF.write.format(format).options(options).save(intDir) + longDF.write.format(format).options(options).save(longDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToDoubleTypeTest extends ReadSchemaTest { + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val floatDF = values.map(_.toFloat).toDF("col1") + private lazy val doubleDF = values.map(_.toDouble).toDF("col1") + private lazy val unionDF = floatDF.union(doubleDF) + + test("change column type from float to double") { + withTempPath { dir => + val path = dir.getCanonicalPath + + floatDF.write.format(format).options(options).save(path) + + val df = spark.read.schema("col1 double").format(format).options(options).load(path) + + checkAnswer(df, doubleDF) + } + } + + test("read float and double together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val floatDir = s"$path${File.separator}part=float" + val doubleDir = s"$path${File.separator}part=double" + + floatDF.write.format(format).options(options).save(floatDir) + doubleDF.write.format(format).options(options).save(doubleDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToDecimalTypeTest extends ReadSchemaTest { + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val floatDF = values.map(_.toFloat).toDF("col1") + private lazy val doubleDF = values.map(_.toDouble).toDF("col1") + private lazy val decimalDF = values.map(BigDecimal(_)).toDF("col1") + private lazy val unionDF = floatDF.union(doubleDF).union(decimalDF) + + test("change column type from float to decimal") { + withTempPath { dir => + val path = dir.getCanonicalPath + + floatDF.write.format(format).options(options).save(path) + + val df = spark.read + .schema("col1 decimal(38,18)") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, decimalDF) + } + } + + test("change column type from double to decimal") { + withTempPath { dir => + val path = dir.getCanonicalPath + + doubleDF.write.format(format).options(options).save(path) + + val df = spark.read + .schema("col1 decimal(38,18)") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, decimalDF) + } + } + + test("read float, double, decimal together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val floatDir = s"$path${File.separator}part=float" + val doubleDir = s"$path${File.separator}part=double" + val decimalDir = s"$path${File.separator}part=decimal" + + floatDF.write.format(format).options(options).save(floatDir) + doubleDF.write.format(format).options(options).save(doubleDir) + decimalDF.write.format(format).options(options).save(decimalDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala index 4b3ca8e60cab6..a1da3ec43eae3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala @@ -23,9 +23,6 @@ import org.apache.spark.sql.test.SharedSQLContext class SaveIntoDataSourceCommandSuite extends SharedSQLContext { - override protected def sparkConf: SparkConf = super.sparkConf - .set("spark.redaction.regex", "(?i)password|url") - test("simpleString is redacted") { val URL = "connection.url" val PASS = "123" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala new file mode 100644 index 0000000000000..1a3dacb8398e6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.csv + +import java.io.File + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{Column, Row, SparkSession} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ +import org.apache.spark.util.{Benchmark, Utils} + +/** + * Benchmark to measure CSV read/write performance. + * To run this: + * spark-submit --class --jars + */ +object CSVBenchmarks { + val conf = new SparkConf() + + val spark = SparkSession.builder + .master("local[1]") + .appName("benchmark-csv-datasource") + .config(conf) + .getOrCreate() + import spark.implicits._ + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def quotedValuesBenchmark(rowsNum: Int, numIters: Int): Unit = { + val benchmark = new Benchmark(s"Parsing quoted values", rowsNum) + + withTempPath { path => + val str = (0 until 10000).map(i => s""""$i"""").mkString(",") + + spark.range(rowsNum) + .map(_ => str) + .write.option("header", true) + .csv(path.getAbsolutePath) + + val schema = new StructType().add("value", StringType) + val ds = spark.read.option("header", true).schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"One quoted string", numIters) { _ => + ds.filter((_: Row) => true).count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + Parsing quoted values: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + One quoted string 30273 / 30549 0.0 605451.2 1.0X + */ + benchmark.run() + } + } + + def multiColumnsBenchmark(rowsNum: Int): Unit = { + val colsNum = 1000 + val benchmark = new Benchmark(s"Wide rows with $colsNum columns", rowsNum) + + withTempPath { path => + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + val values = (0 until colsNum).map(i => i.toString).mkString(",") + val columnNames = schema.fieldNames + + spark.range(rowsNum) + .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) + .write.option("header", true) + .csv(path.getAbsolutePath) + + val ds = spark.read.schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"Select $colsNum columns", 3) { _ => + ds.select("*").filter((row: Row) => true).count() + } + val cols100 = columnNames.take(100).map(Column(_)) + benchmark.addCase(s"Select 100 columns", 3) { _ => + ds.select(cols100: _*).filter((row: Row) => true).count() + } + benchmark.addCase(s"Select one column", 3) { _ => + ds.select($"col1").filter((row: Row) => true).count() + } + benchmark.addCase(s"count()", 3) { _ => + ds.count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + Select 1000 columns 81091 / 81692 0.0 81090.7 1.0X + Select 100 columns 30003 / 34448 0.0 30003.0 2.7X + Select one column 24792 / 24855 0.0 24792.0 3.3X + count() 24344 / 24642 0.0 24343.8 3.3X + */ + benchmark.run() + } + } + + def main(args: Array[String]): Unit = { + quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) + multiColumnsBenchmark(rowsNum = 1000 * 1000) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 661742087112f..842251be92c18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.types._ class CSVInferSchemaSuite extends SparkFunSuite { test("String fields types are inferred correctly from null types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(NullType, "", options) == NullType) assert(CSVInferSchema.inferField(NullType, null, options) == NullType) assert(CSVInferSchema.inferField(NullType, "100000000000", options) == LongType) @@ -41,7 +41,7 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("String fields types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType) assert(CSVInferSchema.inferField(LongType, "test", options) == StringType) assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == DoubleType) @@ -60,21 +60,21 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("Timestamp field types are inferred correctly via custom data format") { - var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), "GMT") + var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), false, "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) - options = new CSVOptions(Map("timestampFormat" -> "yyyy"), "GMT") + options = new CSVOptions(Map("timestampFormat" -> "yyyy"), false, "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType) } test("Timestamp field types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == StringType) assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) == StringType) assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == StringType) } test("Boolean fields types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType) assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == StringType) } @@ -92,12 +92,12 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("Null fields are handled properly when a nullValue is specified") { - var options = new CSVOptions(Map("nullValue" -> "null"), "GMT") + var options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") assert(CSVInferSchema.inferField(NullType, "null", options) == NullType) assert(CSVInferSchema.inferField(StringType, "null", options) == StringType) assert(CSVInferSchema.inferField(LongType, "null", options) == LongType) - options = new CSVOptions(Map("nullValue" -> "\\N"), "GMT") + options = new CSVOptions(Map("nullValue" -> "\\N"), false, "GMT") assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType) assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType) assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType) @@ -111,12 +111,12 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { - val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), "GMT") + val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), false, "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) } test("SPARK-18877: `inferField` on DecimalType should find a common type with `typeSoFar`") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9). assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) == @@ -134,7 +134,7 @@ class CSVInferSchemaSuite extends SparkFunSuite { test("DoubleType should be infered when user defined nan/inf are provided") { val options = new CSVOptions(Map("nanValue" -> "nan", "negativeInf" -> "-inf", - "positiveInf" -> "inf"), "GMT") + "positiveInf" -> "inf"), false, "GMT") assert(CSVInferSchema.inferField(NullType, "nan", options) == DoubleType) assert(CSVInferSchema.inferField(NullType, "inf", options) == DoubleType) assert(CSVInferSchema.inferField(NullType, "-inf", options) == DoubleType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 4398e547d9217..84b91f6309fe8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -23,19 +23,22 @@ import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Locale +import scala.collection.JavaConverters._ + import org.apache.commons.lang3.time.FastDateFormat import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.log4j.{AppenderSkeleton, LogManager} +import org.apache.log4j.spi.LoggingEvent import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.functions.{col, regexp_replace} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ -class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { +class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with TestCsvData { import testImplicits._ private val carsFile = "test-data/cars.csv" @@ -261,14 +264,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for DROPMALFORMED parsing mode") { - Seq(false, true).foreach { multiLine => - val cars = spark.read - .format("csv") - .option("multiLine", multiLine) - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { + Seq(false, true).foreach { multiLine => + val cars = spark.read + .format("csv") + .option("multiLine", multiLine) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) + assert(cars.select("year").collect().size === 2) + } } } @@ -735,39 +740,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(numbers.count() == 8) } - test("error handling for unsupported data types.") { - withTempDir { dir => - val csvDir = new File(dir, "csv").getCanonicalPath - var msg = intercept[UnsupportedOperationException] { - Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support struct data type")) - - msg = intercept[UnsupportedOperationException] { - Seq((1, Map("Tesla" -> 3))).toDF("id", "cars").write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support map data type")) - - msg = intercept[UnsupportedOperationException] { - Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands").write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support array data type")) - - msg = intercept[UnsupportedOperationException] { - Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") - .write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support array data type")) - - msg = intercept[UnsupportedOperationException] { - val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) - spark.range(1).write.csv(csvDir) - spark.read.schema(schema).csv(csvDir).collect() - }.getMessage - assert(msg.contains("CSV data source does not support array data type.")) - } - } - test("SPARK-15585 turn off quotations") { val cars = spark.read .format("csv") @@ -1279,4 +1251,332 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil ) } + + test("SPARK-23846: schema inferring touches less data if samplingRatio < 1.0") { + // Set default values for the DataSource parameters to make sure + // that whole test file is mapped to only one partition. This will guarantee + // reliable sampling of the input file. + withSQLConf( + "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, + "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + )(withTempPath { path => + val ds = sampledTestData.coalesce(1) + ds.write.text(path.getAbsolutePath) + + val readback = spark.read + .option("inferSchema", true).option("samplingRatio", 0.1) + .csv(path.getCanonicalPath) + assert(readback.schema == new StructType().add("_c0", IntegerType)) + }) + } + + test("SPARK-23846: usage of samplingRatio while parsing a dataset of strings") { + val ds = sampledTestData.coalesce(1) + val readback = spark.read + .option("inferSchema", true).option("samplingRatio", 0.1) + .csv(ds) + + assert(readback.schema == new StructType().add("_c0", IntegerType)) + } + + test("SPARK-23846: samplingRatio is out of the range (0, 1.0]") { + val ds = spark.range(0, 100, 1, 1).map(_.toString) + + val errorMsg0 = intercept[IllegalArgumentException] { + spark.read.option("inferSchema", true).option("samplingRatio", -1).csv(ds) + }.getMessage + assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0")) + + val errorMsg1 = intercept[IllegalArgumentException] { + spark.read.option("inferSchema", true).option("samplingRatio", 0).csv(ds) + }.getMessage + assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0")) + + val sampled = spark.read.option("inferSchema", true).option("samplingRatio", 1.0).csv(ds) + assert(sampled.count() == ds.count()) + } + + test("SPARK-17916: An empty string should not be coerced to null when nullValue is passed.") { + val litNull: String = null + val df = Seq( + (1, "John Doe"), + (2, ""), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + // Checks for new behavior where an empty string is not coerced to null when `nullValue` is + // set to anything but an empty string literal. + withTempPath { path => + df.write + .option("nullValue", "-") + .csv(path.getAbsolutePath) + val computed = spark.read + .option("nullValue", "-") + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, ""), + (3, litNull), + (4, litNull) + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + // Keeps the old behavior where empty string us coerced to nullValue is not passed. + withTempPath { path => + df.write + .csv(path.getAbsolutePath) + val computed = spark.read + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, litNull), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + } + + test("SPARK-24329: skip lines with comments, and one or multiple whitespaces") { + val schema = new StructType().add("colA", StringType) + val ds = spark + .read + .schema(schema) + .option("multiLine", false) + .option("header", true) + .option("comment", "#") + .option("ignoreLeadingWhiteSpace", false) + .option("ignoreTrailingWhiteSpace", false) + .csv(testFile("test-data/comments-whitespaces.csv")) + + checkAnswer(ds, Seq(Row(""" "a" """))) + } + + test("SPARK-24244: Select a subset of all columns") { + withTempPath { path => + import collection.JavaConverters._ + val schema = new StructType() + .add("f1", IntegerType).add("f2", IntegerType).add("f3", IntegerType) + .add("f4", IntegerType).add("f5", IntegerType).add("f6", IntegerType) + .add("f7", IntegerType).add("f8", IntegerType).add("f9", IntegerType) + .add("f10", IntegerType).add("f11", IntegerType).add("f12", IntegerType) + .add("f13", IntegerType).add("f14", IntegerType).add("f15", IntegerType) + + val odf = spark.createDataFrame(List( + Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), + Row(-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15) + ).asJava, schema) + odf.write.csv(path.getCanonicalPath) + val idf = spark.read + .schema(schema) + .csv(path.getCanonicalPath) + .select('f15, 'f10, 'f5) + + assert(idf.count() == 2) + checkAnswer(idf, List(Row(15, 10, 5), Row(-15, -10, -5))) + } + } + + def checkHeader(multiLine: Boolean): Unit = { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + val exception = intercept[SparkException] { + spark.read + .schema(ischema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + + val shortSchema = new StructType().add("f1", DoubleType) + val exceptionForShortSchema = intercept[SparkException] { + spark.read + .schema(shortSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exceptionForShortSchema.getMessage.contains( + "Number of column in CSV header is not equal to number of fields in the schema")) + + val longSchema = new StructType() + .add("f1", DoubleType) + .add("f2", DoubleType) + .add("f3", DoubleType) + + val exceptionForLongSchema = intercept[SparkException] { + spark.read + .schema(longSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exceptionForLongSchema.getMessage.contains("Header length: 2, schema size: 3")) + + val caseSensitiveSchema = new StructType().add("F1", DoubleType).add("f2", DoubleType) + val caseSensitiveException = intercept[SparkException] { + spark.read + .schema(caseSensitiveSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(caseSensitiveException.getMessage.contains( + "CSV header does not conform to the schema")) + } + } + } + + test(s"SPARK-23786: Checking column names against schema in the multiline mode") { + checkHeader(multiLine = true) + } + + test(s"SPARK-23786: Checking column names against schema in the per-line mode") { + checkHeader(multiLine = false) + } + + test("SPARK-23786: CSV header must not be checked if it doesn't exist") { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", false).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + val idf = spark.read + .schema(ischema) + .option("header", false) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + + checkAnswer(idf, odf) + } + } + + test("SPARK-23786: Ignore column name case if spark.sql.caseSensitive is false") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTempPath { path => + val oschema = new StructType().add("A", StringType) + val odf = spark.createDataFrame(List(Row("0")).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("a", StringType) + val idf = spark.read.schema(ischema) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + checkAnswer(idf, odf) + } + } + } + + test("SPARK-23786: check header on parsing of dataset of strings") { + val ds = Seq("columnA,columnB", "1.0,1000.0").toDS() + val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType) + val exception = intercept[IllegalArgumentException] { + spark.read.schema(ischema).option("header", true).option("enforceSchema", false).csv(ds) + } + + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + } + + test("SPARK-23786: enforce inferred schema") { + val expectedSchema = new StructType().add("_c0", DoubleType).add("_c1", StringType) + val withHeader = spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .csv(Seq("_c0,_c1", "1.0,a").toDS()) + assert(withHeader.schema == expectedSchema) + checkAnswer(withHeader, Seq(Row(1.0, "a"))) + + // Ignore the inferSchema flag if an user sets a schema + val schema = new StructType().add("colA", DoubleType).add("colB", StringType) + val ds = spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .schema(schema) + .csv(Seq("colA,colB", "1.0,a").toDS()) + assert(ds.schema == schema) + checkAnswer(ds, Seq(Row(1.0, "a"))) + + val exception = intercept[IllegalArgumentException] { + spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .schema(schema) + .csv(Seq("col1,col2", "1.0,a").toDS()) + } + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + } + + test("SPARK-23786: warning should be printed if CSV header doesn't conform to schema") { + class TestAppender extends AppenderSkeleton { + var events = new java.util.ArrayList[LoggingEvent] + override def close(): Unit = {} + override def requiresLayout: Boolean = false + protected def append(event: LoggingEvent): Unit = events.add(event) + } + + val testAppender1 = new TestAppender + LogManager.getRootLogger.addAppender(testAppender1) + try { + val ds = Seq("columnA,columnB", "1.0,1000.0").toDS() + val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType) + + spark.read.schema(ischema).option("header", true).option("enforceSchema", true).csv(ds) + } finally { + LogManager.getRootLogger.removeAppender(testAppender1) + } + assert(testAppender1.events.asScala + .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) + + val testAppender2 = new TestAppender + LogManager.getRootLogger.addAppender(testAppender2) + try { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + spark.read + .schema(ischema) + .option("header", true) + .option("enforceSchema", true) + .csv(path.getCanonicalPath) + .collect() + } + } finally { + LogManager.getRootLogger.removeAppender(testAppender2) + } + assert(testAppender2.events.asScala + .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) + } + + test("SPARK-24645 skip parsing when columnPruning enabled and partitions scanned only") { + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "true") { + withTempPath { path => + val dir = path.getAbsolutePath + spark.range(10).selectExpr("id % 2 AS p", "id").write.partitionBy("p").csv(dir) + checkAnswer(spark.read.csv(dir).selectExpr("sum(p)"), Row(5)) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala new file mode 100644 index 0000000000000..3e20cc47dca2c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} + +private[csv] trait TestCsvData { + protected def spark: SparkSession + + def sampledTestData: Dataset[String] = { + spark.range(0, 100, 1).map { index => + val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(index)) { + index.toString + } else { + (index.toDouble + 0.1).toString + } + }(Encoders.STRING) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala index efbf73534bd19..458edb253fb33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala @@ -26,8 +26,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParserSuite extends SparkFunSuite { - private val parser = - new UnivocityParser(StructType(Seq.empty), new CSVOptions(Map.empty[String, String], "GMT")) + private val parser = new UnivocityParser( + StructType(Seq.empty), + new CSVOptions(Map.empty[String, String], false, "GMT")) private def assertNull(v: Any) = assert(v == null) @@ -38,7 +39,7 @@ class UnivocityParserSuite extends SparkFunSuite { stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => val decimalValue = new BigDecimal(decimalVal.toString) - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(parser.makeConverter("_1", decimalType, options = options).apply(strVal) === Decimal(decimalValue, decimalType.precision, decimalType.scale)) } @@ -51,21 +52,21 @@ class UnivocityParserSuite extends SparkFunSuite { // Nullable field with nullValue option. types.foreach { t => // Tests that a custom nullValue. - val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), "GMT") + val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") val converter = parser.makeConverter("_1", t, nullable = true, options = nullValueOptions) assertNull(converter.apply("-")) assertNull(converter.apply(null)) // Tests that the default nullValue is empty string. - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assertNull(parser.makeConverter("_1", t, nullable = true, options = options).apply("")) } // Not nullable field with nullValue option. types.foreach { t => // Casts a null to not nullable field should throw an exception. - val options = new CSVOptions(Map("nullValue" -> "-"), "GMT") + val options = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") val converter = parser.makeConverter("_1", t, nullable = false, options = options) var message = intercept[RuntimeException] { @@ -81,7 +82,7 @@ class UnivocityParserSuite extends SparkFunSuite { // If nullValue is different with empty string, then, empty string should not be casted into // null. Seq(true, false).foreach { b => - val options = new CSVOptions(Map("nullValue" -> "null"), "GMT") + val options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") val converter = parser.makeConverter("_1", StringType, nullable = b, options = options) assert(converter.apply("") == UTF8String.fromString("")) @@ -89,7 +90,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Throws exception for empty string with non null type") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") val exception = intercept[RuntimeException]{ parser.makeConverter("_1", IntegerType, nullable = false, options = options).apply("") } @@ -97,7 +98,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Types are cast correctly") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(parser.makeConverter("_1", ByteType, options = options).apply("10") == 10) assert(parser.makeConverter("_1", ShortType, options = options).apply("10") == 10) assert(parser.makeConverter("_1", IntegerType, options = options).apply("10") == 10) @@ -107,7 +108,7 @@ class UnivocityParserSuite extends SparkFunSuite { assert(parser.makeConverter("_1", BooleanType, options = options).apply("true") == true) val timestampsOptions = - new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), "GMT") + new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), false, "GMT") val customTimestamp = "31/01/2015 00:00" val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime val castedTimestamp = @@ -116,7 +117,7 @@ class UnivocityParserSuite extends SparkFunSuite { assert(castedTimestamp == expectedTime * 1000L) val customDate = "31/01/2015" - val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), "GMT") + val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), false, "GMT") val expectedDate = dateOptions.dateFormat.parse(customDate).getTime val castedDate = parser.makeConverter("_1", DateType, nullable = true, options = dateOptions) @@ -131,7 +132,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Throws exception for casting an invalid string to Float and Double Types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") val types = Seq(DoubleType, FloatType) val input = Seq("10u000", "abc", "1 2/3") types.foreach { dt => @@ -145,7 +146,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Float NaN values are parsed correctly") { - val options = new CSVOptions(Map("nanValue" -> "nn"), "GMT") + val options = new CSVOptions(Map("nanValue" -> "nn"), false, "GMT") val floatVal: Float = parser.makeConverter( "_1", FloatType, nullable = true, options = options ).apply("nn").asInstanceOf[Float] @@ -156,7 +157,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Double NaN values are parsed correctly") { - val options = new CSVOptions(Map("nanValue" -> "-"), "GMT") + val options = new CSVOptions(Map("nanValue" -> "-"), false, "GMT") val doubleVal: Double = parser.makeConverter( "_1", DoubleType, nullable = true, options = options ).apply("-").asInstanceOf[Double] @@ -165,14 +166,14 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Float infinite values can be parsed") { - val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), "GMT") + val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") val floatVal1 = parser.makeConverter( "_1", FloatType, nullable = true, options = negativeInfOptions ).apply("max").asInstanceOf[Float] assert(floatVal1 == Float.NegativeInfinity) - val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), "GMT") + val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") val floatVal2 = parser.makeConverter( "_1", FloatType, nullable = true, options = positiveInfOptions ).apply("max").asInstanceOf[Float] @@ -181,14 +182,14 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Double infinite values can be parsed") { - val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), "GMT") + val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") val doubleVal1 = parser.makeConverter( "_1", DoubleType, nullable = true, options = negativeInfOptions ).apply("max").asInstanceOf[Double] assert(doubleVal1 == Double.NegativeInfinity) - val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), "GMT") + val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") val doubleVal2 = parser.makeConverter( "_1", DoubleType, nullable = true, options = positiveInfOptions ).apply("max").asInstanceOf[Double] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala new file mode 100644 index 0000000000000..85cf054e51f6b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.json + +import java.io.File + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.{LongType, StringType, StructType} +import org.apache.spark.util.{Benchmark, Utils} + +/** + * The benchmarks aims to measure performance of JSON parsing when encoding is set and isn't. + * To run this: + * spark-submit --class --jars + */ +object JSONBenchmarks { + val conf = new SparkConf() + + val spark = SparkSession.builder + .master("local[1]") + .appName("benchmark-json-datasource") + .config(conf) + .getOrCreate() + import spark.implicits._ + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + + def schemaInferring(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON schema inferring", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map(_ => "a") + .toDF("fieldA") + .write + .option("encoding", "UTF-8") + .json(path.getAbsolutePath) + + benchmark.addCase("No encoding", 3) { _ => + spark.read.json(path.getAbsolutePath) + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .json(path.getAbsolutePath) + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON schema inferring: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 38902 / 39282 2.6 389.0 1.0X + UTF-8 is set 56959 / 57261 1.8 569.6 0.7X + */ + benchmark.run() + } + } + + def perlineParsing(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON per-line parsing", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map(_ => "a") + .toDF("fieldA") + .write.json(path.getAbsolutePath) + val schema = new StructType().add("fieldA", StringType) + + benchmark.addCase("No encoding", 3) { _ => + spark.read + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON per-line parsing: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 25947 / 26188 3.9 259.5 1.0X + UTF-8 is set 46319 / 46417 2.2 463.2 0.6X + */ + benchmark.run() + } + } + + def perlineParsingOfWideColumn(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON parsing of wide lines", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map { i => + val s = "abcdef0123456789ABCDEF" * 20 + s"""{"a":"$s","b": $i,"c":"$s","d":$i,"e":"$s","f":$i,"x":"$s","y":$i,"z":"$s"}""" + } + .toDF().write.text(path.getAbsolutePath) + val schema = new StructType() + .add("a", StringType).add("b", LongType) + .add("c", StringType).add("d", LongType) + .add("e", StringType).add("f", LongType) + .add("x", StringType).add("y", LongType) + .add("z", StringType) + + benchmark.addCase("No encoding", 3) { _ => + spark.read + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON parsing of wide lines: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 45543 / 45660 0.2 4554.3 1.0X + UTF-8 is set 65737 / 65957 0.2 6573.7 0.7X + */ + benchmark.run() + } + } + + def main(args: Array[String]): Unit = { + schemaInferring(100 * 1000 * 1000) + perlineParsing(100 * 1000 * 1000) + perlineParsingOfWideColumn(10 * 1000 * 1000) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 70aee561ff0f6..eab15b35c97d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.{File, StringWriter} -import java.nio.charset.StandardCharsets -import java.nio.file.{Files, Paths, StandardOpenOption} +import java.io._ +import java.nio.charset.{Charset, StandardCharsets, UnsupportedCharsetException} +import java.nio.file.Files import java.sql.{Date, Timestamp} import java.util.Locale @@ -31,11 +31,11 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} -import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} +import org.apache.spark.sql.catalyst.json.JsonInferSchema.compatibleType import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.ExternalRDD import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -48,6 +48,10 @@ class TestFileFilter extends PathFilter { class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ + def testFile(fileName: String): String = { + Thread.currentThread().getContextClassLoader.getResource(fileName).toString + } + test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { assert(expected.getClass == actual.getClass, @@ -2128,38 +2132,366 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } - test("SPARK-23849: schema inferring touches less data if samplingRation < 1.0") { - val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, - 57, 62, 68, 72) + test("SPARK-23849: schema inferring touches less data if samplingRatio < 1.0") { + // Set default values for the DataSource parameters to make sure + // that whole test file is mapped to only one partition. This will guarantee + // reliable sampling of the input file. + withSQLConf( + "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, + "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + )(withTempPath { path => + val ds = sampledTestData.coalesce(1) + ds.write.text(path.getAbsolutePath) + val readback = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) + + assert(readback.schema == new StructType().add("f1", LongType)) + }) + } + + test("SPARK-23849: usage of samplingRatio while parsing a dataset of strings") { + val ds = sampledTestData.coalesce(1) + val readback = spark.read.option("samplingRatio", 0.1).json(ds) + + assert(readback.schema == new StructType().add("f1", LongType)) + } + + test("SPARK-23849: samplingRatio is out of the range (0, 1.0]") { + val ds = spark.range(0, 100, 1, 1).map(_.toString) + + val errorMsg0 = intercept[IllegalArgumentException] { + spark.read.option("samplingRatio", -1).json(ds) + }.getMessage + assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0")) + + val errorMsg1 = intercept[IllegalArgumentException] { + spark.read.option("samplingRatio", 0).json(ds) + }.getMessage + assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0")) + + val sampled = spark.read.option("samplingRatio", 1.0).json(ds) + assert(sampled.count() == ds.count()) + } + + test("SPARK-23723: json in UTF-16 with BOM") { + val fileName = "test-data/utf16WithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .option("encoding", "UTF-16") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"), Row("Doug", "Rood"))) + } + + test("SPARK-23723: multi-line json in UTF-32BE with BOM") { + val fileName = "test-data/utf32BEWithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("SPARK-23723: Use user's encoding in reading of multi-line json in UTF-16LE") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .options(Map("encoding" -> "UTF-16LE")) + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("SPARK-23723: Unsupported encoding name") { + val invalidCharset = "UTF-128" + val exception = intercept[UnsupportedCharsetException] { + spark.read + .options(Map("encoding" -> invalidCharset, "lineSep" -> "\n")) + .json(testFile("test-data/utf16LE.json")) + .count() + } + + assert(exception.getMessage.contains(invalidCharset)) + } + + test("SPARK-23723: checking that the encoding option is case agnostic") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .options(Map("encoding" -> "uTf-16lE")) + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + + test("SPARK-23723: specified encoding is not matched to actual encoding") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val exception = intercept[SparkException] { + spark.read.schema(schema) + .option("mode", "FAILFAST") + .option("multiline", "true") + .options(Map("encoding" -> "UTF-16BE")) + .json(testFile(fileName)) + .count() + } + val errMsg = exception.getMessage + + assert(errMsg.contains("Malformed records are detected in record parsing")) + } + + def checkEncoding(expectedEncoding: String, pathToJsonFiles: String, + expectedContent: String): Unit = { + val jsonFiles = new File(pathToJsonFiles) + .listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("json")) + val actualContent = jsonFiles.map { file => + new String(Files.readAllBytes(file.toPath), expectedEncoding) + }.mkString.trim + + assert(actualContent == expectedContent) + } + + test("SPARK-23723: save json in UTF-32BE") { + val encoding = "UTF-32BE" + withTempPath { path => + val df = spark.createDataset(Seq(("Dog", 42))) + df.write + .options(Map("encoding" -> encoding)) + .json(path.getCanonicalPath) + + checkEncoding( + expectedEncoding = encoding, + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""") + } + } + + test("SPARK-23723: save json in default encoding - UTF-8") { + withTempPath { path => + val df = spark.createDataset(Seq(("Dog", 42))) + df.write.json(path.getCanonicalPath) + + checkEncoding( + expectedEncoding = "UTF-8", + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""") + } + } + + test("SPARK-23723: wrong output encoding") { + val encoding = "UTF-128" + val exception = intercept[SparkException] { + withTempPath { path => + val df = spark.createDataset(Seq((0))) + df.write + .options(Map("encoding" -> encoding)) + .json(path.getCanonicalPath) + } + } + + val baos = new ByteArrayOutputStream() + val ps = new PrintStream(baos, true, "UTF-8") + exception.printStackTrace(ps) + ps.flush() + + assert(baos.toString.contains( + "java.nio.charset.UnsupportedCharsetException: UTF-128")) + } + + test("SPARK-23723: read back json in UTF-16LE") { + val options = Map("encoding" -> "UTF-16LE", "lineSep" -> "\n") withTempPath { path => - val writer = Files.newBufferedWriter(Paths.get(path.getAbsolutePath), - StandardCharsets.UTF_8, StandardOpenOption.CREATE_NEW) - for (i <- 0 until 100) { - if (predefinedSample.contains(i)) { - writer.write(s"""{"f1":${i.toString}}""" + "\n") + val ds = spark.createDataset(Seq(("a", 1), ("b", 2), ("c", 3))).repartition(2) + ds.write.options(options).json(path.getCanonicalPath) + + val readBack = spark + .read + .options(options) + .json(path.getCanonicalPath) + + checkAnswer(readBack.toDF(), ds.toDF()) + } + } + + test("SPARK-23723: write json in UTF-16/32 with multiline off") { + Seq("UTF-16", "UTF-32").foreach { encoding => + withTempPath { path => + val ds = spark.createDataset(Seq(("a", 1))).repartition(1) + ds.write + .option("encoding", encoding) + .option("multiline", false) + .json(path.getCanonicalPath) + val jsonFiles = path.listFiles().filter(_.getName.endsWith("json")) + jsonFiles.foreach { jsonFile => + val readback = Files.readAllBytes(jsonFile.toPath) + val expected = ("""{"_1":"a","_2":1}""" + "\n").getBytes(Charset.forName(encoding)) + assert(readback === expected) + } + } + } + } + + def checkReadJson(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = { + test(s"SPARK-23724: checks reading json in ${encoding} #${id}") { + val schema = new StructType().add("f1", StringType).add("f2", IntegerType) + withTempPath { path => + val records = List(("a", 1), ("b", 2)) + val data = records + .map(rec => s"""{"f1":"${rec._1}", "f2":${rec._2}}""".getBytes(encoding)) + .reduce((a1, a2) => a1 ++ lineSep.getBytes(encoding) ++ a2) + val os = new FileOutputStream(path) + os.write(data) + os.close() + val reader = if (inferSchema) { + spark.read } else { - writer.write(s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n") + spark.read.schema(schema) } + val readBack = reader + .option("encoding", encoding) + .option("lineSep", lineSep) + .json(path.getCanonicalPath) + checkAnswer(readBack, records.map(rec => Row(rec._1, rec._2))) } - writer.close() + } + } + + // scalastyle:off nonascii + List( + (0, "|", "UTF-8", false), + (1, "^", "UTF-16BE", true), + (2, "::", "ISO-8859-1", true), + (3, "!!!@3", "UTF-32LE", false), + (4, 0x1E.toChar.toString, "UTF-8", true), + (5, "아", "UTF-32BE", false), + (6, "куку", "CP1251", true), + (7, "sep", "utf-8", false), + (8, "\r\n", "UTF-16LE", false), + (9, "\r\n", "utf-16be", true), + (10, "\u000d\u000a", "UTF-32BE", false), + (11, "\u000a\u000d", "UTF-8", true), + (12, "===", "US-ASCII", false), + (13, "$^+", "utf-32le", true) + ).foreach { + case (testNum, sep, encoding, inferSchema) => checkReadJson(sep, encoding, inferSchema, testNum) + } + // scalastyle:on nonascii + + test("SPARK-23724: lineSep should be set if encoding if different from UTF-8") { + val encoding = "UTF-16LE" + val exception = intercept[IllegalArgumentException] { + spark.read + .options(Map("encoding" -> encoding)) + .json(testFile("test-data/utf16LE.json")) + .count() + } + + assert(exception.getMessage.contains( + s"""The lineSep option must be specified for the $encoding encoding""")) + } - val ds = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) - assert(ds.schema == new StructType().add("f1", LongType)) + private val badJson = "\u0000\u0000\u0000A\u0001AAA" + + test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is enabled") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson + """{"a":1}""").toDS().write.text(path) + val expected = s"""${badJson}{"a":1}\n""" + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = spark.read.format("json") + .option("mode", "PERMISSIVE") + .option("multiLine", true) + .option("encoding", "UTF-8") + .schema(schema).load(path) + checkAnswer(df, Row(null, expected)) } } - test("SPARK-23849: usage of samplingRation while parsing of dataset of strings") { - val dstr = spark.sparkContext.parallelize(0 until 100, 1).map { i => - val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, - 57, 62, 68, 72) - if (predefinedSample.contains(i)) { - s"""{"f1":${i.toString}}""" + "\n" - } else { - s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n" - } - }.toDS() - val ds = spark.read.option("samplingRatio", 0.1).json(dstr) + test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is disabled") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson, """{"a":1}""").toDS().write.text(path) + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = spark.read.format("json") + .option("mode", "PERMISSIVE") + .option("multiLine", false) + .option("encoding", "UTF-8") + .schema(schema).load(path) + checkAnswer(df, Seq(Row(1, null), Row(null, badJson))) + } + } + + test("SPARK-23094: permissively parse a dataset contains JSON with leading nulls") { + checkAnswer( + spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()), + Row(badJson)) + } + + test("SPARK-23772 ignore column of all null values or empty array during schema inference") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + + // primitive types + Seq( + """{"a":null, "b":1, "c":3.0}""", + """{"a":null, "b":null, "c":"string"}""", + """{"a":null, "b":null, "c":null}""") + .toDS().write.text(path) + var df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + var expectedSchema = new StructType() + .add("b", LongType).add("c", StringType) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(1, "3.0") :: Row(null, "string") :: Row(null, null) :: Nil) - assert(ds.schema == new StructType().add("f1", LongType)) + // arrays + Seq( + """{"a":[2, 1], "b":[null, null], "c":null, "d":[[], [null]], "e":[[], null, [[]]]}""", + """{"a":[null], "b":[null], "c":[], "d":[null, []], "e":null}""", + """{"a":null, "b":null, "c":[], "d":null, "e":[null, [], null]}""") + .toDS().write.mode("overwrite").text(path) + df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + expectedSchema = new StructType() + .add("a", ArrayType(LongType)) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(Array(2, 1)) :: Row(Array(null)) :: Row(null) :: Nil) + + // structs + Seq( + """{"a":{"a1": 1, "a2":"string"}, "b":{}}""", + """{"a":{"a1": 2, "a2":null}, "b":{"b1":[null]}}""", + """{"a":null, "b":null}""") + .toDS().write.mode("overwrite").text(path) + df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + expectedSchema = new StructType() + .add("a", StructType(StructField("a1", LongType) :: StructField("a2", StringType) + :: Nil)) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(Row(1, "string")) :: Row(Row(2, null)) :: Row(null) :: Nil) + } + } + + test("SPARK-24190: restrictions for JSONOptions in read") { + for (encoding <- Set("UTF-16", "UTF-32")) { + val exception = intercept[IllegalArgumentException] { + spark.read + .option("encoding", encoding) + .option("multiLine", false) + .json(testFile("test-data/utf16LE.json")) + .count() + } + assert(exception.getMessage.contains("encoding must not be included in the blacklist")) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 13084ba4a7f04..6e9559edf8ec2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -233,4 +233,16 @@ private[json] trait TestJsonData { spark.createDataset(spark.sparkContext.parallelize("""{"a":123}""" :: Nil))(Encoders.STRING) def empty: Dataset[String] = spark.emptyDataset(Encoders.STRING) + + def sampledTestData: Dataset[String] = { + spark.range(0, 100, 1).map { index => + val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(index)) { + s"""{"f1":${index.toString}}""" + } else { + s"""{"f1":${(index.toDouble + 0.1).toString}}""" + } + }(Encoders.STRING) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 8a3bbd03a26dc..02bfb7197ffc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.orc import java.io.File +import java.sql.Timestamp import java.util.Locale import org.apache.orc.OrcConf.COMPRESS @@ -169,6 +170,14 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { } } } + + test("SPARK-24322 Fix incorrect workaround for bug in java.sql.Timestamp") { + withTempPath { path => + val ts = Timestamp.valueOf("1900-05-05 12:34:56.000789") + Seq(ts).toDF.write.orc(path.getCanonicalPath) + checkAnswer(spark.read.orc(path.getCanonicalPath), Row(ts)) + } + } } class OrcSourceSuite extends OrcSuite with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala index f3ecc5ced689f..4b2437803d645 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala @@ -91,9 +91,14 @@ class ParquetCommitterSuite extends SparkFunSuite with SQLTestUtils summary: Boolean, check: Boolean): Option[FileStatus] = { var result: Option[FileStatus] = None + val summaryLevel = if (summary) { + "ALL" + } else { + "NONE" + } withSQLConf( SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key -> committer, - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> summary.toString) { + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> summaryLevel) { withTempPath { dest => val df = spark.createDataFrame(Seq((1, "4"), (2, "2"))) val destPath = new Path(dest.toURI) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 1d3476e747046..924f136503656 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.charset.StandardCharsets -import java.sql.Date +import java.sql.{Date, Timestamp} -import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate.Operators.{Column => _, _} @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} @@ -55,6 +56,10 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} */ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { + private lazy val parquetFilters = + new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, + conf.parquetFilterPushDownStringStartWith, conf.parquetFilterPushDownInFilterThreshold) + override def beforeEach(): Unit = { super.beforeEach() // Note that there are many tests here that require record-level filtering set to be true. @@ -80,6 +85,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex withSQLConf( SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> "true", SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df .select(output.map(e => Column(e)): _*) @@ -99,7 +106,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assert(selectedFilters.nonEmpty, "No filter is pushed down") selectedFilters.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(df.schema, pred) + val maybeFilter = parquetFilters.createFilter( + new SparkToParquetSchemaConverter(conf).convert(df.schema), pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) maybeFilter.exists(_.getClass === filterClass) @@ -138,6 +146,64 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } + private def testTimestampPushdown(data: Seq[Timestamp]): Unit = { + assert(data.size === 4) + val ts1 = data.head + val ts2 = data(1) + val ts3 = data(2) + val ts4 = data(3) + + withParquetDataFrame(data.map(i => Tuple1(i))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i))) + + checkFilterPredicate('_1 === ts1, classOf[Eq[_]], ts1) + checkFilterPredicate('_1 <=> ts1, classOf[Eq[_]], ts1) + checkFilterPredicate('_1 =!= ts1, classOf[NotEq[_]], + Seq(ts2, ts3, ts4).map(i => Row.apply(i))) + + checkFilterPredicate('_1 < ts2, classOf[Lt[_]], ts1) + checkFilterPredicate('_1 > ts1, classOf[Gt[_]], Seq(ts2, ts3, ts4).map(i => Row.apply(i))) + checkFilterPredicate('_1 <= ts1, classOf[LtEq[_]], ts1) + checkFilterPredicate('_1 >= ts4, classOf[GtEq[_]], ts4) + + checkFilterPredicate(Literal(ts1) === '_1, classOf[Eq[_]], ts1) + checkFilterPredicate(Literal(ts1) <=> '_1, classOf[Eq[_]], ts1) + checkFilterPredicate(Literal(ts2) > '_1, classOf[Lt[_]], ts1) + checkFilterPredicate(Literal(ts3) < '_1, classOf[Gt[_]], ts4) + checkFilterPredicate(Literal(ts1) >= '_1, classOf[LtEq[_]], ts1) + checkFilterPredicate(Literal(ts4) <= '_1, classOf[GtEq[_]], ts4) + + checkFilterPredicate(!('_1 < ts4), classOf[GtEq[_]], ts4) + checkFilterPredicate('_1 < ts2 || '_1 > ts3, classOf[Operators.Or], Seq(Row(ts1), Row(ts4))) + } + } + + // This function tests that exactly go through the `canDrop` and `inverseCanDrop`. + private def testStringStartsWith(dataFrame: DataFrame, filter: String): Unit = { + withTempPath { dir => + val path = dir.getCanonicalPath + dataFrame.write.option("parquet.block.size", 512).parquet(path) + Seq(true, false).foreach { pushDown => + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> pushDown.toString) { + val accu = new NumRowGroupsAcc + sparkContext.register(accu) + + val df = spark.read.parquet(path).filter(filter) + df.foreachPartition((it: Iterator[Row]) => it.foreach(v => accu.add(0))) + if (pushDown) { + assert(accu.value == 0) + } else { + assert(accu.value > 0) + } + + AccumulatorContext.remove(accu.id) + } + } + } + } + test("filter pushdown - boolean") { withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) @@ -149,6 +215,62 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - tinyint") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toByte)))) { implicit df => + assert(df.schema.head.dataType === ByteType) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1.toByte, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1.toByte, classOf[Eq[_]], 1) + checkFilterPredicate('_1 =!= 1.toByte, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2.toByte, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3.toByte, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1.toByte, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4.toByte, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1.toByte) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1.toByte) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2.toByte) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3.toByte) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1.toByte) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4.toByte) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4.toByte), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2.toByte || '_1 > 3.toByte, + classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + + test("filter pushdown - smallint") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => + assert(df.schema.head.dataType === ShortType) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1.toShort, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1.toShort, classOf[Eq[_]], 1) + checkFilterPredicate('_1 =!= 1.toShort, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2.toShort, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3.toShort, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1.toShort, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4.toShort, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1.toShort) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1.toShort) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2.toShort) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3.toShort) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1.toShort) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4.toShort) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4.toShort), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2.toShort || '_1 > 3.toShort, + classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + test("filter pushdown - integer") { withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) @@ -357,6 +479,39 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - timestamp") { + // spark.sql.parquet.outputTimestampType = TIMESTAMP_MILLIS + val millisData = Seq(Timestamp.valueOf("2018-06-14 08:28:53.123"), + Timestamp.valueOf("2018-06-15 08:28:53.123"), + Timestamp.valueOf("2018-06-16 08:28:53.123"), + Timestamp.valueOf("2018-06-17 08:28:53.123")) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) { + testTimestampPushdown(millisData) + } + + // spark.sql.parquet.outputTimestampType = TIMESTAMP_MICROS + val microsData = Seq(Timestamp.valueOf("2018-06-14 08:28:53.123456"), + Timestamp.valueOf("2018-06-15 08:28:53.123456"), + Timestamp.valueOf("2018-06-16 08:28:53.123456"), + Timestamp.valueOf("2018-06-17 08:28:53.123456")) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.TIMESTAMP_MICROS.toString) { + testTimestampPushdown(microsData) + } + + // spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.INT96.toString) { + withParquetDataFrame(millisData.map(i => Tuple1(i))) { implicit df => + assertResult(None) { + parquetFilters.createFilter( + new SparkToParquetSchemaConverter(conf).convert(df.schema), sources.IsNull("_1")) + } + } + } + } + test("SPARK-6554: don't push down predicates which reference partition columns") { import testImplicits._ @@ -513,28 +668,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex StructField("c", DoubleType, nullable = true) )) + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + assertResult(Some(and( lt(intColumn("a"), 10: Integer), gt(doubleColumn("c"), 1.5: java.lang.Double))) ) { - ParquetFilters.createFilter( - schema, + parquetFilters.createFilter( + parquetSchema, sources.And( sources.LessThan("a", 10), sources.GreaterThan("c", 1.5D))) } assertResult(None) { - ParquetFilters.createFilter( - schema, + parquetFilters.createFilter( + parquetSchema, sources.And( sources.LessThan("a", 10), sources.StringContains("b", "prefix"))) } assertResult(None) { - ParquetFilters.createFilter( - schema, + parquetFilters.createFilter( + parquetSchema, sources.Not( sources.And( sources.GreaterThan("a", 1), @@ -572,7 +729,6 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val df = spark.read.parquet(path).filter("a < 100") df.foreachPartition((it: Iterator[Row]) => it.foreach(v => accu.add(0))) - df.collect if (enablePushDown) { assert(accu.value == 0) @@ -587,21 +743,25 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } test("SPARK-17213: Broken Parquet filter push-down for string columns") { - withTempPath { dir => - import testImplicits._ + Seq(true, false).foreach { vectorizedEnabled => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedEnabled.toString) { + withTempPath { dir => + import testImplicits._ - val path = dir.getCanonicalPath - // scalastyle:off nonascii - Seq("a", "é").toDF("name").write.parquet(path) - // scalastyle:on nonascii + val path = dir.getCanonicalPath + // scalastyle:off nonascii + Seq("a", "é").toDF("name").write.parquet(path) + // scalastyle:on nonascii - assert(spark.read.parquet(path).where("name > 'a'").count() == 1) - assert(spark.read.parquet(path).where("name >= 'a'").count() == 2) + assert(spark.read.parquet(path).where("name > 'a'").count() == 1) + assert(spark.read.parquet(path).where("name >= 'a'").count() == 2) - // scalastyle:off nonascii - assert(spark.read.parquet(path).where("name < 'é'").count() == 1) - assert(spark.read.parquet(path).where("name <= 'é'").count() == 2) - // scalastyle:on nonascii + // scalastyle:off nonascii + assert(spark.read.parquet(path).where("name < 'é'").count() == 1) + assert(spark.read.parquet(path).where("name <= 'é'").count() == 2) + // scalastyle:on nonascii + } + } } } @@ -646,6 +806,133 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + + test("SPARK-23852: Broken Parquet push-down for partially-written stats") { + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + // parquet-1217.parquet contains a single column with values -1, 0, 1, 2 and null. + // The row-group statistics include null counts, but not min and max values, which + // triggers PARQUET-1217. + val df = readResourceParquetFile("test-data/parquet-1217.parquet") + + // Will return 0 rows if PARQUET-1217 is not fixed. + assert(df.where("col > 0").count() === 2) + } + } + + test("filter pushdown - StringStartsWith") { + withParquetDataFrame((1 to 4).map(i => Tuple1(i + "str" + i))) { implicit df => + checkFilterPredicate( + '_1.startsWith("").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq("1str1", "2str2", "3str3", "4str4").map(Row(_))) + + Seq("2", "2s", "2st", "2str", "2str2").foreach { prefix => + checkFilterPredicate( + '_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + "2str2") + } + + Seq("2S", "null", "2str22").foreach { prefix => + checkFilterPredicate( + '_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq.empty[Row]) + } + + checkFilterPredicate( + !'_1.startsWith("").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq().map(Row(_))) + + Seq("2", "2s", "2st", "2str", "2str2").foreach { prefix => + checkFilterPredicate( + !'_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq("1str1", "3str3", "4str4").map(Row(_))) + } + + Seq("2S", "null", "2str22").foreach { prefix => + checkFilterPredicate( + !'_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq("1str1", "2str2", "3str3", "4str4").map(Row(_))) + } + + assertResult(None) { + parquetFilters.createFilter( + new SparkToParquetSchemaConverter(conf).convert(df.schema), + sources.StringStartsWith("_1", null)) + } + } + + import testImplicits._ + // Test canDrop() has taken effect + testStringStartsWith(spark.range(1024).map(_.toString).toDF(), "value like 'a%'") + // Test inverseCanDrop() has taken effect + testStringStartsWith(spark.range(1024).map(c => "100").toDF(), "value not like '10%'") + } + + test("SPARK-17091: Convert IN predicate to Parquet filter push-down") { + val schema = StructType(Seq( + StructField("a", IntegerType, nullable = false) + )) + + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + + assertResult(Some(FilterApi.eq(intColumn("a"), null: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(null))) + } + + assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10))) + } + + // Remove duplicates + assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 10))) + } + + assertResult(Some(or(or( + FilterApi.eq(intColumn("a"), 10: Integer), + FilterApi.eq(intColumn("a"), 20: Integer)), + FilterApi.eq(intColumn("a"), 30: Integer))) + ) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 20, 30))) + } + + assert(parquetFilters.createFilter(parquetSchema, sources.In("a", + Range(0, conf.parquetFilterPushDownInFilterThreshold).toArray)).isDefined) + assert(parquetFilters.createFilter(parquetSchema, sources.In("a", + Range(0, conf.parquetFilterPushDownInFilterThreshold + 1).toArray)).isEmpty) + + import testImplicits._ + withTempPath { path => + val data = 0 to 1024 + data.toDF("a").selectExpr("if (a = 1024, null, a) AS a") // convert 1024 to null + .coalesce(1).write.option("parquet.block.size", 512) + .parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath) + Seq(true, false).foreach { pushEnabled => + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> pushEnabled.toString) { + Seq(1, 5, 10, 11).foreach { count => + val filter = s"a in(${Range(0, count).mkString(",")})" + assert(df.where(filter).count() === count) + val actual = stripSparkFilter(df.where(filter)).collect().length + if (pushEnabled && count <= conf.parquetFilterPushDownInFilterThreshold) { + assert(actual > 1 && actual < data.length) + } else { + assert(actual === data.length) + } + } + assert(df.where("a in(null)").count() === 0) + assert(df.where("a = null").count() === 0) + assert(df.where("a is null").count() === 1) + } + } + } + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 0b3e8ca060d87..002c42f23bd64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -543,7 +543,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val hadoopConf = spark.sessionState.newHadoopConfWithOptions(extraOptions) - withSQLConf(ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true") { + withSQLConf(ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part-r-0.parquet" spark.range(1 << 16).selectExpr("(id % 4) AS i") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index e887c9734a8b8..9966ed94a8392 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -1014,7 +1014,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val path = dir.getCanonicalPath withSQLConf( - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true", + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL", "spark.sql.sources.commitProtocolClass" -> classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName) { spark.range(3).write.parquet(s"$path/p0=0/p1=0") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index e1f094d0a7af3..dbf637783e6d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -275,7 +275,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName, SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true", - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true" + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL" ) { testSchemaMerging(2) } @@ -879,6 +879,18 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } } + + test("SPARK-24230: filter row group using dictionary") { + withSQLConf(("parquet.filter.dictionary.enabled", "true")) { + // create a table with values from 0, 2, ..., 18 that will be dictionary-encoded + withParquetTable((0 until 100).map(i => ((i * 2) % 20, s"data-$i")), "t") { + // search for a key that is not present so the dictionary filter eliminates all row groups + // Fails without SPARK-24230: + // java.io.IOException: expecting more rows but reached last block. Read 0 out of 50 + checkAnswer(sql("SELECT _2 FROM t WHERE t._1 = 5"), Seq.empty) + } + } + } } object TestingUDT { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala deleted file mode 100644 index e43336d947364..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ /dev/null @@ -1,339 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.parquet - -import java.io.File - -import scala.collection.JavaConverters._ -import scala.util.Try - -import org.apache.spark.SparkConf -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{Benchmark, Utils} - -/** - * Benchmark to measure parquet read performance. - * To run this: - * spark-submit --class --jars - */ -object ParquetReadBenchmark { - val conf = new SparkConf() - conf.set("spark.sql.parquet.compression.codec", "snappy") - - val spark = SparkSession.builder - .master("local[1]") - .appName("test-sql-context") - .config(conf) - .getOrCreate() - - // Set default configs. Individual cases will change them if necessary. - spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") - spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") - - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(spark.catalog.dropTempView) - } - - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } - - def intScanBenchmark(values: Int): Unit = { - // Benchmarks running through spark sql. - val sqlBenchmark = new Benchmark("SQL Single Int Column Scan", values) - // Benchmarks driving reader component directly. - val parquetReaderBenchmark = new Benchmark("Parquet Reader Single Int Column Scan", values) - - withTempPath { dir => - withTempTable("t1", "tempTable") { - val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled - val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select cast(id as INT) as id from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - sqlBenchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(id) from tempTable").collect() - } - - sqlBenchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(id) from tempTable").collect() - } - } - - val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - // Driving the parquet reader in batch mode directly. - parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num => - var sum = 0L - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - null, enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("id" :: Nil).asJava) - val batch = reader.resultBatch() - val col = batch.column(0) - while (reader.nextBatch()) { - val numRows = batch.numRows() - var i = 0 - while (i < numRows) { - if (!col.isNullAt(i)) sum += col.getInt(i) - i += 1 - } - } - } finally { - reader.close() - } - } - } - - // Decoding in vectorized but having the reader return rows. - parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => - var sum = 0L - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - null, enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("id" :: Nil).asJava) - val batch = reader.resultBatch() - while (reader.nextBatch()) { - val it = batch.rowIterator() - while (it.hasNext) { - val record = it.next() - if (!record.isNullAt(0)) sum += record.getInt(0) - } - } - } finally { - reader.close() - } - } - } - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - SQL Single Int Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 215 / 262 73.0 13.7 1.0X - SQL Parquet MR 1946 / 2083 8.1 123.7 0.1X - */ - sqlBenchmark.run() - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Parquet Reader Single Int Column Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - ParquetReader Vectorized 123 / 152 127.8 7.8 1.0X - ParquetReader Vectorized -> Row 165 / 180 95.2 10.5 0.7X - */ - parquetReaderBenchmark.run() - } - } - } - - def intStringScanBenchmark(values: Int): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("Int and String Scan", values) - - benchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect - } - - benchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect - } - } - - val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 628 / 720 16.7 59.9 1.0X - SQL Parquet MR 1905 / 2239 5.5 181.7 0.3X - */ - benchmark.run() - } - } - } - - def stringDictionaryScanBenchmark(values: Int): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("String Dictionary", values) - - benchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(length(c1)) from tempTable").collect - } - - benchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(length(c1)) from tempTable").collect - } - } - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - String Dictionary: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 329 / 337 31.9 31.4 1.0X - SQL Parquet MR 1131 / 1325 9.3 107.8 0.3X - */ - benchmark.run() - } - } - } - - def partitionTableScanBenchmark(values: Int): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select id % 2 as p, cast(id as INT) as id from t1") - .write.partitionBy("p").parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("Partitioned Table", values) - - benchmark.addCase("Read data column") { iter => - spark.sql("select sum(id) from tempTable").collect - } - - benchmark.addCase("Read partition column") { iter => - spark.sql("select sum(p) from tempTable").collect - } - - benchmark.addCase("Read both columns") { iter => - spark.sql("select sum(p), sum(id) from tempTable").collect - } - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Read data column 191 / 250 82.1 12.2 1.0X - Read partition column 82 / 86 192.4 5.2 2.3X - Read both columns 220 / 248 71.5 14.0 0.9X - */ - benchmark.run() - } - } - } - - def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled - val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize - spark.range(values).createOrReplaceTempView("t1") - spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + - s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("String with Nulls Scan", values) - - benchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(length(c2)) from tempTable where c1 is " + - "not NULL and c2 is not NULL").collect() - } - - val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - benchmark.addCase("PR Vectorized") { num => - var sum = 0 - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - null, enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) - val batch = reader.resultBatch() - while (reader.nextBatch()) { - val rowIterator = batch.rowIterator() - while (rowIterator.hasNext) { - val row = rowIterator.next() - val value = row.getUTF8String(0) - if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes() - } - } - } finally { - reader.close() - } - } - } - - /* - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - String with Nulls Scan (0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 1229 / 1648 8.5 117.2 1.0X - PR Vectorized 833 / 846 12.6 79.4 1.5X - - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - String with Nulls Scan (50%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 995 / 1053 10.5 94.9 1.0X - PR Vectorized 732 / 772 14.3 69.8 1.4X - - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - String with Nulls Scan (95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 326 / 333 32.2 31.1 1.0X - PR Vectorized 190 / 200 55.1 18.2 1.7X - */ - - benchmark.run() - } - } - } - - def main(args: Array[String]): Unit = { - intScanBenchmark(1024 * 1024 * 15) - intStringScanBenchmark(1024 * 1024 * 10) - stringDictionaryScanBenchmark(1024 * 1024 * 10) - partitionTableScanBenchmark(1024 * 1024 * 15) - for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { - stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index adcaf2d76519f..8251ff159e05f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.TestData @@ -33,14 +34,16 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { } test("debugCodegen") { - val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenString(spark.range(10).groupBy(col("id") * 2).count() + .queryExecution.executedPlan) assert(res.contains("Subtree 1 / 2")) assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) } test("debugCodegenStringSeq") { - val res = codegenStringSeq(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenStringSeq(spark.range(10).groupBy(col("id") * 2).count() + .queryExecution.executedPlan) assert(res.length == 2) assert(res.forall{ case (subtree, code) => subtree.contains("Range") && code.contains("Object[]")}) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 51f8c3325fdff..037cc2e3ccad7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.CompactBuffer @@ -254,6 +254,30 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { map.free() } + test("SPARK-24257: insert big values into LongToUnsafeRowMap") { + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val unsafeProj = UnsafeProjection.create(Array[DataType](StringType)) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + + val key = 0L + // the page array is initialized with length 1 << 17 (1M bytes), + // so here we need a value larger than 1 << 18 (2M bytes), to trigger the bug + val bigStr = UTF8String.fromString("x" * (1 << 19)) + + map.append(key, unsafeProj(InternalRow(bigStr))) + map.optimize() + + val resultRow = new UnsafeRow(1) + assert(map.getValue(key, resultRow).getUTF8String(0) === bigStr) + map.free() + } + test("Spark-14521") { val ser = new KryoSerializer( (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index a3a3f3851e21c..8263c9c81c49e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.metric import java.io.File +import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.ui.SQLAppStatusStore import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -504,4 +504,38 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared test("writing data out metrics with dynamic partition: parquet") { testMetricsDynamicPartition("parquet", "parquet", "t1") } + + test("writing metrics from single thread") { + val nAdds = 10 + val acc = new SQLMetric("test", -10) + assert(acc.isZero()) + acc.set(0) + for (i <- 1 to nAdds) acc.add(1) + assert(!acc.isZero()) + assert(nAdds === acc.value) + acc.reset() + assert(acc.isZero()) + } + + test("writing metrics from multiple threads") { + implicit val ec: ExecutionContextExecutor = ExecutionContext.global + val nFutures = 1000 + val nAdds = 100 + val acc = new SQLMetric("test", -10) + assert(acc.isZero() === true) + acc.set(0) + val l = for ( i <- 1 to nFutures ) yield { + Future { + for (j <- 1 to nAdds) acc.add(1) + i + } + } + for (futures <- Future.sequence(l)) { + assert(nFutures === futures.length) + assert(!acc.isZero()) + assert(nFutures * nAdds === acc.value) + acc.reset() + assert(acc.isZero()) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala new file mode 100644 index 0000000000000..07e6034770127 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark._ +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.python.PythonForeachWriter.UnsafeRowBuffer +import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.util.Utils + +class PythonForeachWriterSuite extends SparkFunSuite with Eventually { + + testWithBuffer("UnsafeRowBuffer: iterator blocks when no data is available") { b => + b.assertIteratorBlocked() + + b.add(Seq(1)) + b.assertOutput(Seq(1)) + b.assertIteratorBlocked() + + b.add(2 to 100) + b.assertOutput(1 to 100) + b.assertIteratorBlocked() + } + + testWithBuffer("UnsafeRowBuffer: iterator unblocks when all data added") { b => + b.assertIteratorBlocked() + b.add(Seq(1)) + b.assertIteratorBlocked() + + b.allAdded() + b.assertThreadTerminated() + b.assertOutput(Seq(1)) + } + + testWithBuffer( + "UnsafeRowBuffer: handles more data than memory", + memBytes = 5, + sleepPerRowReadMs = 1) { b => + + b.assertIteratorBlocked() + b.add(1 to 2000) + b.assertOutput(1 to 2000) + } + + def testWithBuffer( + name: String, + memBytes: Long = 4 << 10, + sleepPerRowReadMs: Int = 0 + )(f: BufferTester => Unit): Unit = { + + test(name) { + var tester: BufferTester = null + try { + tester = new BufferTester(memBytes, sleepPerRowReadMs) + f(tester) + } finally { + if (tester == null) tester.close() + } + } + } + + + class BufferTester(memBytes: Long, sleepPerRowReadMs: Int) { + private val buffer = { + val mem = new TestMemoryManager(new SparkConf()) + mem.limit(memBytes) + val taskM = new TaskMemoryManager(mem, 0) + new UnsafeRowBuffer(taskM, Utils.createTempDir(), 1) + } + private val iterator = buffer.iterator + private val outputBuffer = new ArrayBuffer[Int] + private val testTimeout = timeout(20.seconds) + private val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) + private val thread = new Thread() { + override def run(): Unit = { + while (iterator.hasNext) { + outputBuffer.synchronized { + outputBuffer += iterator.next().getInt(0) + } + Thread.sleep(sleepPerRowReadMs) + } + } + } + thread.start() + + def add(ints: Seq[Int]): Unit = { + ints.foreach { i => buffer.add(intProj.apply(new GenericInternalRow(Array[Any](i)))) } + } + + def allAdded(): Unit = { buffer.allRowsAdded() } + + def assertOutput(expectedOutput: Seq[Int]): Unit = { + eventually(testTimeout) { + val output = outputBuffer.synchronized { outputBuffer.toArray }.toSeq + assert(output == expectedOutput) + } + } + + def assertIteratorBlocked(): Unit = { + import Thread.State._ + eventually(testTimeout) { + assert(thread.isAlive) + assert(thread.getState == TIMED_WAITING || thread.getState == WAITING) + } + } + + def assertThreadTerminated(): Unit = { + eventually(testTimeout) { assert(!thread.isAlive) } + } + + def close(): Unit = { + thread.interrupt() + thread.join() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index e8420eee7fe9d..3bc36ce55d902 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -220,11 +220,11 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { sink.addBatch(0, 1 to 3) plan.invalidateStatsCache() - assert(plan.stats.sizeInBytes === 12) + assert(plan.stats.sizeInBytes === 36) sink.addBatch(1, 4 to 6) plan.invalidateStatsCache() - assert(plan.stats.sizeInBytes === 24) + assert(plan.stats.sizeInBytes === 72) } ignore("stress test") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala new file mode 100644 index 0000000000000..c228740df07c8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.streaming.StreamTest + +class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + test("SPARK-24156: do not plan a no-data batch again after it has already been planned") { + val inputData = MemoryStream[Int] + val df = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(df)( + AddData(inputData, 10, 11, 12, 13, 14, 15), // Set watermark to 5 + CheckAnswer(), + AddData(inputData, 25), // Set watermark to 15 to make MicroBatchExecution run no-data batch + CheckAnswer((10, 5)), // Last batch should be a no-data batch + StopStream, + Execute { q => + // Delete the last committed batch from the commit log to signify that the last batch + // (a no-data batch) never completed + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + q.commitLog.purgeAfter(commit - 1) + }, + // Add data before start so that MicroBatchExecution can plan a batch. It should not, + // it should first re-run the incomplete no-data batch and then run a new batch to process + // new data. + AddData(inputData, 30), + StartStream(), + CheckNewAnswer((15, 1)), // This should not throw the error reported in SPARK-24156 + StopStream, + Execute { q => + // Delete the entire commit log + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + q.commitLog.purge(commit + 1) + }, + AddData(inputData, 50), + StartStream(), + CheckNewAnswer((25, 1), (30, 1)) // This should not throw the error reported in SPARK-24156 + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala new file mode 100644 index 0000000000000..a4233e15e4ffd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import scala.collection.mutable + +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming._ + +case class KV(key: Int, value: Long) + +class ForeachBatchSinkSuite extends StreamTest { + import testImplicits._ + + test("foreachBatch with non-stateful query") { + val mem = MemoryStream[Int] + val ds = mem.toDS.map(_ + 1) + + val tester = new ForeachBatchTester[Int](mem) + val writer = (ds: Dataset[Int], batchId: Long) => tester.record(batchId, ds.map(_ + 1)) + + import tester._ + testWriter(ds, writer)( + check(in = 1, 2, 3)(out = 3, 4, 5), // out = in + 2 (i.e. 1 in query, 1 in writer) + check(in = 5, 6, 7)(out = 7, 8, 9)) + } + + test("foreachBatch with stateful query in update mode") { + val mem = MemoryStream[Int] + val ds = mem.toDF() + .select($"value" % 2 as "key") + .groupBy("key") + .agg(count("*") as "value") + .toDF.as[KV] + + val tester = new ForeachBatchTester[KV](mem) + val writer = (batchDS: Dataset[KV], batchId: Long) => tester.record(batchId, batchDS) + + import tester._ + testWriter(ds, writer, outputMode = OutputMode.Update)( + check(in = 0)(out = (0, 1L)), + check(in = 1)(out = (1, 1L)), + check(in = 2, 3)(out = (0, 2L), (1, 2L))) + } + + test("foreachBatch with stateful query in complete mode") { + val mem = MemoryStream[Int] + val ds = mem.toDF() + .select($"value" % 2 as "key") + .groupBy("key") + .agg(count("*") as "value") + .toDF.as[KV] + + val tester = new ForeachBatchTester[KV](mem) + val writer = (batchDS: Dataset[KV], batchId: Long) => tester.record(batchId, batchDS) + + import tester._ + testWriter(ds, writer, outputMode = OutputMode.Complete)( + check(in = 0)(out = (0, 1L)), + check(in = 1)(out = (0, 1L), (1, 1L)), + check(in = 2)(out = (0, 2L), (1, 1L))) + } + + test("foreachBatchSink does not affect metric generation") { + val mem = MemoryStream[Int] + val ds = mem.toDS.map(_ + 1) + + val tester = new ForeachBatchTester[Int](mem) + val writer = (ds: Dataset[Int], batchId: Long) => tester.record(batchId, ds.map(_ + 1)) + + import tester._ + testWriter(ds, writer)( + check(in = 1, 2, 3)(out = 3, 4, 5), + checkMetrics) + } + + test("throws errors in invalid situations") { + val ds = MemoryStream[Int].toDS + val ex1 = intercept[IllegalArgumentException] { + ds.writeStream.foreachBatch(null.asInstanceOf[(Dataset[Int], Long) => Unit]).start() + } + assert(ex1.getMessage.contains("foreachBatch function cannot be null")) + val ex2 = intercept[AnalysisException] { + ds.writeStream.foreachBatch((_, _) => {}).trigger(Trigger.Continuous("1 second")).start() + } + assert(ex2.getMessage.contains("'foreachBatch' is not supported with continuous trigger")) + val ex3 = intercept[AnalysisException] { + ds.writeStream.foreachBatch((_, _) => {}).partitionBy("value").start() + } + assert(ex3.getMessage.contains("'foreachBatch' does not support partitioning")) + } + + // ============== Helper classes and methods ================= + + private class ForeachBatchTester[T: Encoder](memoryStream: MemoryStream[Int]) { + trait Test + private case class Check(in: Seq[Int], out: Seq[T]) extends Test + private case object CheckMetrics extends Test + + private val recordedOutput = new mutable.HashMap[Long, Seq[T]] + + def testWriter( + ds: Dataset[T], + outputBatchWriter: (Dataset[T], Long) => Unit, + outputMode: OutputMode = OutputMode.Append())(tests: Test*): Unit = { + try { + var expectedBatchId = -1 + val query = ds.writeStream.outputMode(outputMode).foreachBatch(outputBatchWriter).start() + + tests.foreach { + case Check(in, out) => + expectedBatchId += 1 + memoryStream.addData(in) + query.processAllAvailable() + assert(recordedOutput.contains(expectedBatchId)) + val ds: Dataset[T] = spark.createDataset[T](recordedOutput(expectedBatchId)) + checkDataset[T](ds, out: _*) + case CheckMetrics => + assert(query.recentProgress.exists(_.numInputRows > 0)) + } + } finally { + sqlContext.streams.active.foreach(_.stop()) + } + } + + def check(in: Int*)(out: T*): Test = Check(in, out) + def checkMetrics: Test = CheckMetrics + def record(batchId: Long, ds: Dataset[T]): Unit = recordedOutput.put(batchId, ds.collect()) + implicit def conv(x: (Int, Long)): KV = KV(x._1, x._2) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index 03bf71b3f4b78..e60c339bc9cc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -211,14 +211,12 @@ class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAnd try { inputData.addData(10, 11, 12) query.processAllAvailable() - inputData.addData(25) // Advance watermark to 15 seconds - query.processAllAvailable() inputData.addData(25) // Evict items less than previous watermark query.processAllAvailable() // There should be 3 batches and only does the last batch contain a value. val allEvents = ForeachWriterSuite.allEvents() - assert(allEvents.size === 3) + assert(allEvents.size === 4) val expectedEvents = Seq( Seq( ForeachWriterSuite.Open(partition = 0, version = 0), @@ -230,6 +228,10 @@ class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAnd ), Seq( ForeachWriterSuite.Open(partition = 0, version = 2), + ForeachWriterSuite.Close(None) + ), + Seq( + ForeachWriterSuite.Open(partition = 0, version = 3), ForeachWriterSuite.Process(value = 3), ForeachWriterSuite.Close(None) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index ff14ec38e66a8..bf72e5c99689f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -142,9 +142,9 @@ class RateSourceSuite extends StreamTest { val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() + val tasks = reader.planInputPartitions() assert(tasks.size == 1) - val dataReader = tasks.get(0).createDataReader() + val dataReader = tasks.get(0).createPartitionReader() val data = ArrayBuffer[Row]() while (dataReader.next()) { data.append(dataReader.get()) @@ -159,11 +159,11 @@ class RateSourceSuite extends StreamTest { val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() + val tasks = reader.planInputPartitions() assert(tasks.size == 11) val readData = tasks.asScala - .map(_.createDataReader()) + .map(_.createPartitionReader()) .flatMap { reader => val buf = scala.collection.mutable.ListBuffer[Row]() while (reader.next()) buf.append(reader.get()) @@ -304,17 +304,17 @@ class RateSourceSuite extends StreamTest { val reader = new RateStreamContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() + val tasks = reader.planInputPartitions() assert(tasks.size == 2) val data = scala.collection.mutable.ListBuffer[Row]() tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => + case t: RateStreamContinuousInputPartition => val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + val r = t.createPartitionReader().asInstanceOf[RateStreamContinuousInputPartitionReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index a15a980bb92fd..52e8386f6b1fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.streaming.sources -import java.io.IOException -import java.net.InetSocketAddress +import java.net.{InetSocketAddress, SocketException} import java.nio.ByteBuffer import java.nio.channels.ServerSocketChannel import java.sql.Timestamp @@ -33,9 +32,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} @@ -101,7 +101,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { val ref = spark import ref.implicits._ @@ -130,7 +130,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { val socket = spark .readStream .format("socket") @@ -216,20 +216,11 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before "socket source does not support a user-specified schema")) } - test("no server up") { - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> "0") - intercept[IOException] { - batchReader = provider.createMicroBatchReader( - Optional.empty(), "", new DataSourceOptions(parameters.asJava)) - } - } - test("input row metrics") { serverThread = new ServerThread() serverThread.start() - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { val ref = spark import ref.implicits._ @@ -256,6 +247,66 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before } } + test("verify ServerThread only accepts the first connection") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { + val ref = spark + import ref.implicits._ + + val socket = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + assert(socket.schema === StructType(StructField("value", StringType) :: Nil)) + + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AddSocketData("world"), + CheckLastBatch("world"), + CheckAnswer("hello", "world"), + StopStream + ) + + // we are trying to connect to the server once again which should fail + try { + val socket2 = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + testStream(socket2)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AddSocketData("world"), + CheckLastBatch("world"), + CheckAnswer("hello", "world"), + StopStream + ) + + fail("StreamingQueryException is expected!") + } catch { + case e: StreamingQueryException if e.cause.isInstanceOf[SocketException] => // pass + } + } + } + + /** + * This class tries to mimic the behavior of netcat, so that we can ensure + * TextSocketStream supports netcat, which only accepts the first connection + * and exits the process when the first connection is closed. + * + * Please refer SPARK-24466 for more details. + */ private class ServerThread extends Thread with Logging { private val serverSocketChannel = ServerSocketChannel.open() serverSocketChannel.bind(new InetSocketAddress(0)) @@ -265,36 +316,24 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before override def run(): Unit = { try { + val clientSocketChannel = serverSocketChannel.accept() + + // Close server socket channel immediately to mimic the behavior that + // only first connection will be made and deny any further connections + // Note that the first client socket channel will be available + serverSocketChannel.close() + + clientSocketChannel.configureBlocking(false) + clientSocketChannel.socket().setTcpNoDelay(true) + while (true) { - val clientSocketChannel = serverSocketChannel.accept() - clientSocketChannel.configureBlocking(false) - clientSocketChannel.socket().setTcpNoDelay(true) - - // Check whether remote client is closed but still send data to this closed socket. - // This happens in DataStreamReader where a source will be created to get the schema. - var remoteIsClosed = false - var cnt = 0 - while (cnt < 3 && !remoteIsClosed) { - if (clientSocketChannel.read(ByteBuffer.allocate(1)) != -1) { - cnt += 1 - Thread.sleep(100) - } else { - remoteIsClosed = true - } - } - - if (remoteIsClosed) { - logInfo(s"remote client ${clientSocketChannel.socket()} is closed") - } else { - while (true) { - val line = messageQueue.take() + "\n" - clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8"))) - } - } + val line = messageQueue.take() + "\n" + clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8"))) } } catch { case e: InterruptedException => } finally { + // no harm to call close() again... serverSocketChannel.close() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 65b39f0fbd73d..579a364ebc3e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -55,7 +55,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("versioning and immutability") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( increment) @@ -73,7 +73,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("recovering from files") { - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString def makeStoreRDD( spark: SparkSession, @@ -101,7 +101,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("usage with iterators - only gets and only puts") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val opId = 0 // Returns an iterator of the incremented value made into the store @@ -149,7 +149,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn quietly { val queryRunId = UUID.randomUUID val opId = 0 - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext @@ -189,7 +189,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn .config(sparkConf.setMaster("local-cluster[2, 1, 1024]")) .getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(increment) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index f3f08839c1d3a..02df45d1b7989 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -443,7 +443,8 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with val oldCount = statusStore.executionsList().size val expectedAccumValue = 12345 - val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue) + val expectedAccumValue2 = 54321 + val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue, expectedAccumValue2) val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) { override lazy val sparkPlan = physicalPlan override lazy val executedPlan = physicalPlan @@ -466,10 +467,14 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with val execId = statusStore.executionsList().last.executionId val metrics = statusStore.executionMetrics(execId) val driverMetric = physicalPlan.metrics("dummy") + val driverMetric2 = physicalPlan.metrics("dummy2") val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, Seq(expectedAccumValue)) + val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType, Seq(expectedAccumValue2)) assert(metrics.contains(driverMetric.id)) assert(metrics(driverMetric.id) === expectedValue) + assert(metrics.contains(driverMetric2.id)) + assert(metrics(driverMetric2.id) === expectedValue2) } test("roundtripping SparkListenerDriverAccumUpdates through JsonProtocol (SPARK-18462)") { @@ -562,20 +567,31 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with * A dummy [[org.apache.spark.sql.execution.SparkPlan]] that updates a [[SQLMetrics]] * on the driver. */ -private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExecNode { +private case class MyPlan(sc: SparkContext, expectedValue: Long, expectedValue2: Long) + extends LeafExecNode { + override def sparkContext: SparkContext = sc override def output: Seq[Attribute] = Seq() override val metrics: Map[String, SQLMetric] = Map( - "dummy" -> SQLMetrics.createMetric(sc, "dummy")) + "dummy" -> SQLMetrics.createMetric(sc, "dummy"), + "dummy2" -> SQLMetrics.createMetric(sc, "dummy2")) override def doExecute(): RDD[InternalRow] = { longMetric("dummy") += expectedValue + longMetric("dummy2") += expectedValue2 + + // postDriverMetricUpdates may happen multiple time in a query. + // (normally from different operators, but for the sake of testing, from one operator) + SQLMetrics.postDriverMetricUpdates( + sc, + sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), + Seq(metrics("dummy"))) SQLMetrics.postDriverMetricUpdates( sc, sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), - metrics.values.toSeq) + Seq(metrics("dummy2"))) sc.emptyRDD } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 772f687526008..f57f07b498261 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1333,4 +1333,11 @@ class ColumnarBatchSuite extends SparkFunSuite { column.close() } + + testVector("WritableColumnVector.reserve(): requested capacity is negative", 1024, ByteType) { + column => + val ex = intercept[RuntimeException] { column.reserve(-1) } + assert(ex.getMessage.contains( + "Cannot reserve additional contiguous bytes in the vectorized reader (integer overflow)")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala new file mode 100644 index 0000000000000..855fe4f4523f2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.test.SQLTestUtils + +class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { + import testImplicits._ + + protected var spark: SparkSession = null + + // Create a new [[SparkSession]] running in local-cluster mode. + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + spark.stop() + spark = null + } + + override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + pairs.foreach { case (k, v) => + SQLConf.get.setConfString(k, v) + } + try f finally { + pairs.foreach { case (k, _) => + SQLConf.get.unsetConf(k) + } + } + } + + test("ReadOnlySQLConf is correctly created at the executor side") { + withSQLConf("spark.sql.x" -> "a") { + val checks = spark.range(10).mapPartitions { _ => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") + }.collect() + assert(checks.forall(_ == true)) + } + } + + test("case-sensitive config should work for json schema inference") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val pathString = path.getCanonicalPath + spark.range(10).select('id.as("ID")).write.json(pathString) + spark.range(10).write.mode("append").json(pathString) + assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) + } + } + } + + test("SPARK-24727 CODEGEN_CACHE_MAX_ENTRIES is correctly referenced at the executor side") { + withSQLConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES.key -> "300") { + val checks = spark.range(10).mapPartitions { _ => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && + conf.getConfString(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES.key) == "300") + }.collect() + assert(checks.forall(_ == true)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5238adce4a699..0389273d6cdfa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,20 +25,21 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils} +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils} import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCSuite extends SparkFunSuite +class JDBCSuite extends QueryTest with BeforeAndAfter with PrivateMethodTester with SharedSQLContext { import testImplicits._ @@ -238,6 +239,11 @@ class JDBCSuite extends SparkFunSuite |OPTIONS (url '$url', dbtable 'TEST."mixedCaseCols"', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement("CREATE TABLE test.partition (THEID INTEGER, `THE ID` INTEGER) " + + "AS SELECT 1, 1") + .executeUpdate() + conn.commit() + // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. } @@ -1093,7 +1099,7 @@ class JDBCSuite extends SparkFunSuite test("SPARK-19318: Connection properties keys should be case-sensitive.") { def testJdbcOptions(options: JDBCOptions): Unit = { // Spark JDBC data source options are case-insensitive - assert(options.table == "t1") + assert(options.tableOrQuery == "t1") // When we convert it to properties, it should be case-sensitive. assert(options.asProperties.size == 3) assert(options.asProperties.get("customkey") == null) @@ -1190,4 +1196,151 @@ class JDBCSuite extends SparkFunSuite assert(sql("select * from people_view").schema === schema) } } + + test("SPARK-23856 Spark jdbc setQueryTimeout option") { + val numJoins = 100 + val longRunningQuery = + s"SELECT t0.NAME AS c0, ${(1 to numJoins).map(i => s"t$i.NAME AS c$i").mkString(", ")} " + + s"FROM test.people t0 ${(1 to numJoins).map(i => s"join test.people t$i").mkString(" ")}" + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("dbtable", s"($longRunningQuery)") + .option("queryTimeout", 1) + .load() + val errMsg = intercept[SparkException] { + df.collect() + }.getMessage + assert(errMsg.contains("Statement was canceled or the session timed out")) + } + + test("SPARK-24327 verify and normalize a partition column based on a JDBC resolved schema") { + def testJdbcParitionColumn(partColName: String, expectedColumnName: String): Unit = { + val df = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "TEST.PARTITION") + .option("partitionColumn", partColName) + .option("lowerBound", 1) + .option("upperBound", 4) + .option("numPartitions", 3) + .load() + + val quotedPrtColName = testH2Dialect.quoteIdentifier(expectedColumnName) + df.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + s"$quotedPrtColName < 2 or $quotedPrtColName is null", + s"$quotedPrtColName >= 2 AND $quotedPrtColName < 3", + s"$quotedPrtColName >= 3")) + } + } + + testJdbcParitionColumn("THEID", "THEID") + testJdbcParitionColumn("\"THEID\"", "THEID") + withSQLConf("spark.sql.caseSensitive" -> "false") { + testJdbcParitionColumn("ThEiD", "THEID") + } + testJdbcParitionColumn("THE ID", "THE ID") + + def testIncorrectJdbcPartitionColumn(partColName: String): Unit = { + val errMsg = intercept[AnalysisException] { + testJdbcParitionColumn(partColName, "THEID") + }.getMessage + assert(errMsg.contains(s"User-defined partition column $partColName not found " + + "in the JDBC relation:")) + } + + testIncorrectJdbcPartitionColumn("NoExistingColumn") + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + testIncorrectJdbcPartitionColumn(testH2Dialect.quoteIdentifier("ThEiD")) + } + } + + test("query JDBC option - negative tests") { + val query = "SELECT * FROM test.people WHERE theid = 1" + // load path + val e1 = intercept[RuntimeException] { + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("query", query) + .option("dbtable", "test.people") + .load() + }.getMessage + assert(e1.contains("Both 'dbtable' and 'query' can not be specified at the same time.")) + + // jdbc api path + val properties = new Properties() + properties.setProperty(JDBCOptions.JDBC_QUERY_STRING, query) + val e2 = intercept[RuntimeException] { + spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", properties).collect() + }.getMessage + assert(e2.contains("Both 'dbtable' and 'query' can not be specified at the same time.")) + + val e3 = intercept[RuntimeException] { + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', query '$query', dbtable 'TEST.PEOPLE', + | user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + }.getMessage + assert(e3.contains("Both 'dbtable' and 'query' can not be specified at the same time.")) + + val e4 = intercept[RuntimeException] { + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("query", "") + .load() + }.getMessage + assert(e4.contains("Option `query` can not be empty.")) + + // Option query and partitioncolumn are not allowed together. + val expectedErrorMsg = + s""" + |Options 'query' and 'partitionColumn' can not be specified together. + |Please define the query using `dbtable` option instead and make sure to qualify + |the partition columns using the supplied subquery alias to resolve any ambiguity. + |Example : + |spark.read.format("jdbc") + | .option("dbtable", "(select c1, c2 from t1) as subq") + | .option("partitionColumn", "subq.c1" + | .load() + """.stripMargin + val e5 = intercept[RuntimeException] { + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', query '$query', user 'testUser', password 'testPass', + | partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') + """.stripMargin.replaceAll("\n", " ")) + }.getMessage + assert(e5.contains(expectedErrorMsg)) + } + + test("query JDBC option") { + val query = "SELECT name, theid FROM test.people WHERE theid = 1" + // query option to pass on the query string. + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("query", query) + .load() + checkAnswer( + df, + Row("fred", 1) :: Nil) + + // query option in the create table path. + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', query '$query', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + checkAnswer( + sql("select name, theid from queryOption"), + Row("fred", 1) :: Nil) + + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 1985b1dc82879..b751ec2de4825 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -293,13 +293,23 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { test("save errors if dbtable is not specified") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val e = intercept[RuntimeException] { + val e1 = intercept[RuntimeException] { df.write.format("jdbc") .option("url", url1) .options(properties.asScala) .save() }.getMessage - assert(e.contains("Option 'dbtable' is required")) + assert(e1.contains("Option 'dbtable' or 'query' is required")) + + val e2 = intercept[RuntimeException] { + df.write.format("jdbc") + .option("url", url1) + .options(properties.asScala) + .option("query", "select * from TEST.SAVETEST") + .save() + }.getMessage + val msg = "Option 'dbtable' is required. Option 'query' is not applicable while writing." + assert(e2.contains(msg)) } test("save errors if wrong user/password combination") { @@ -515,4 +525,22 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { }.getMessage assert(e.contains("NULL not allowed for column \"NAME\"")) } + + ignore("SPARK-23856 Spark jdbc setQueryTimeout option") { + // The behaviour of the option `queryTimeout` depends on how JDBC drivers implement the API + // `setQueryTimeout`. For example, in the h2 JDBC driver, `executeBatch` invokes multiple + // INSERT queries in a batch and `setQueryTimeout` means that the driver checks the timeout + // of each query. In the PostgreSQL JDBC driver, `setQueryTimeout` means that the driver + // checks the timeout of an entire batch in a driver side. So, the test below fails because + // this test suite depends on the h2 JDBC driver and the JDBC write path internally + // uses `executeBatch`. + val errMsg = intercept[SparkException] { + spark.range(10000000L).selectExpr("id AS k", "id AS v").coalesce(1).write + .mode(SaveMode.Overwrite) + .option("queryTimeout", 1) + .option("batchsize", Int.MaxValue) + .jdbc(url1, "TEST.TIMEOUTTEST", properties) + }.getMessage + assert(errMsg.contains("Statement was canceled or the session timed out")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index fb61fa716b946..a9414200e70f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -22,10 +22,11 @@ import java.net.URI import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.functions._ @@ -52,6 +53,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g") } yield (i % 5, s, i % 13)).toDF("i", "j", "k") + // number of buckets that doesn't yield empty buckets when bucketing on column j on df/nullDF + // empty buckets before filtering might hide bugs in pruning logic + private val NumBucketsForPruningDF = 7 + private val NumBucketsForPruningNullDf = 5 + test("read bucketed data") { withTable("bucketed_table") { df.write @@ -90,32 +96,37 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { originalDataFrame: DataFrame): Unit = { // This test verifies parts of the plan. Disable whole stage codegen. withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { - val strategy = DataSourceStrategy(spark.sessionState.conf) val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec // Limit: bucket pruning only works when the bucket column has one and only one column assert(bucketColumnNames.length == 1) val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head) val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) - val matchedBuckets = new BitSet(numBuckets) - bucketValues.foreach { value => - matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value)) - } // Filter could hide the bug in bucket pruning. Thus, skipping all the filters val plan = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan val rdd = plan.find(_.isInstanceOf[DataSourceScanExec]) assert(rdd.isDefined, plan) - val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => - if (matchedBuckets.get(index % numBuckets) && iter.nonEmpty) Iterator(index) else Iterator() + // if nothing should be pruned, skip the pruning test + if (bucketValues.nonEmpty) { + val matchedBuckets = new BitSet(numBuckets) + bucketValues.foreach { value => + matchedBuckets.set(BucketingUtils.getBucketIdFromValue(bucketColumn, numBuckets, value)) + } + val invalidBuckets = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => + // return indexes of partitions that should have been pruned and are not empty + if (!matchedBuckets.get(index % numBuckets) && iter.nonEmpty) { + Iterator(index) + } else { + Iterator() + } + }.collect() + + if (invalidBuckets.nonEmpty) { + fail(s"Buckets ${invalidBuckets.mkString(",")} should have been pruned from:\n$plan") + } } - // TODO: These tests are not testing the right columns. -// // checking if all the pruned buckets are empty -// val invalidBuckets = checkedResult.collect().toList -// if (invalidBuckets.nonEmpty) { -// fail(s"Buckets $invalidBuckets should have been pruned from:\n$plan") -// } checkAnswer( bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"), @@ -125,7 +136,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read partitioning bucketed tables with bucket pruning filters") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningDF val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here df.write @@ -155,13 +166,21 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { bucketValues = Seq(j, j + 1, j + 2, j + 3), filterCondition = $"j".isin(j, j + 1, j + 2, j + 3), df) + + // Case 4: InSet + val inSetExpr = expressions.InSet($"j".expr, Set(j, j + 1, j + 2, j + 3).map(lit(_).expr)) + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j, j + 1, j + 2, j + 3), + filterCondition = Column(inSetExpr), + df) } } } test("read non-partitioning bucketed tables with bucket pruning filters") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningDF val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here df.write @@ -181,7 +200,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read partitioning bucketed tables having null in bucketing key") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningNullDf val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here nullDF.write @@ -208,7 +227,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read partitioning bucketed tables having composite filters") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningDF val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here df.write @@ -229,7 +248,62 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { bucketValues = j :: Nil, filterCondition = $"j" === j && $"i" > j % 5, df) + + // check multiple bucket values OR condition + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j, j + 1), + filterCondition = $"j" === j || $"j" === (j + 1), + df) + + // check bucket value and none bucket value OR condition + checkPrunedAnswers( + bucketSpec, + bucketValues = Nil, + filterCondition = $"j" === j || $"i" === 0, + df) + + // check AND condition in complex expression + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j), + filterCondition = ($"i" === 0 || $"k" > $"j") && $"j" === j, + df) + } + } + } + + test("read bucketed table without filters") { + withTable("bucketed_table") { + val numBuckets = NumBucketsForPruningDF + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + df.write + .format("json") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") + val plan = bucketedDataFrame.queryExecution.executedPlan + val rdd = plan.find(_.isInstanceOf[DataSourceScanExec]) + assert(rdd.isDefined, plan) + + val emptyBuckets = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => + // return indexes of empty partitions + if (iter.isEmpty) { + Iterator(index) + } else { + Iterator() + } + }.collect() + + if (emptyBuckets.nonEmpty) { + fail(s"Buckets ${emptyBuckets.mkString(",")} should not have been pruned from:\n$plan") } + + checkAnswer( + bucketedDataFrame.orderBy("i", "j", "k"), + df.orderBy("i", "j", "k")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 93f3efe2ccc4a..5ff1ea84d9a7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -60,7 +60,10 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { test("specify sorting columns without bucketing columns") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") - intercept[IllegalArgumentException](df.write.sortBy("j").saveAsTable("tt")) + val e = intercept[AnalysisException] { + df.write.sortBy("j").saveAsTable("tt") + } + assert(e.getMessage == "sortBy must be used together with bucketBy;") } test("sorting by non-orderable column") { @@ -74,7 +77,16 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { val e = intercept[AnalysisException] { df.write.bucketBy(2, "i").parquet("/tmp/path") } - assert(e.getMessage == "'save' does not support bucketing right now;") + assert(e.getMessage == "'save' does not support bucketBy right now;") + } + + test("write bucketed and sorted data using save()") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + + val e = intercept[AnalysisException] { + df.write.bucketBy(2, "i").sortBy("i").parquet("/tmp/path") + } + assert(e.getMessage == "'save' does not support bucketBy and sortBy right now;") } test("write bucketed data using insertInto()") { @@ -83,7 +95,16 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { val e = intercept[AnalysisException] { df.write.bucketBy(2, "i").insertInto("tt") } - assert(e.getMessage == "'insertInto' does not support bucketing right now;") + assert(e.getMessage == "'insertInto' does not support bucketBy right now;") + } + + test("write bucketed and sorted data using insertInto()") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + + val e = intercept[AnalysisException] { + df.write.bucketBy(2, "i").sortBy("i").insertInto("tt") + } + assert(e.getMessage == "'insertInto' does not support bucketBy and sortBy right now;") } private lazy val df = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index fef01c860db6e..438d5d8176b8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -20,12 +20,36 @@ package org.apache.spark.sql.sources import java.io.File import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class SimpleInsertSource extends SchemaRelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { + SimpleInsert(schema)(sqlContext.sparkSession) + } +} + +case class SimpleInsert(userSpecifiedSchema: StructType)(@transient val sparkSession: SparkSession) + extends BaseRelation with InsertableRelation { + + override def sqlContext: SQLContext = sparkSession.sqlContext + + override def schema: StructType = userSpecifiedSchema + + override def insert(input: DataFrame, overwrite: Boolean): Unit = { + input.collect + } +} + class InsertSuite extends DataSourceTest with SharedSQLContext { import testImplicits._ @@ -520,4 +544,29 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } } } + + test("SPARK-24583 Wrong schema type in InsertIntoDataSourceCommand") { + withTable("test_table") { + val schema = new StructType() + .add("i", LongType, false) + .add("s", StringType, false) + val newTable = CatalogTable( + identifier = TableIdentifier("test_table", None), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = None, + outputFormat = None, + serde = None, + compressed = false, + properties = Map.empty), + schema = schema, + provider = Some(classOf[SimpleInsertSource].getName)) + + spark.sessionState.catalog.createTable(newTable, false) + + sql("INSERT INTO TABLE test_table SELECT 1, 'a'") + sql("INSERT INTO TABLE test_table SELECT 2, null") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index e0a53272cd222..e96cd4500458d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -323,21 +323,22 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("SPARK-23315: get output from canonicalized data source v2 related plans") { - def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = { + def checkCanonicalizedOutput( + df: DataFrame, logicalNumOutput: Int, physicalNumOutput: Int): Unit = { val logical = df.queryExecution.optimizedPlan.collect { case d: DataSourceV2Relation => d }.head - assert(logical.canonicalized.output.length == numOutput) + assert(logical.canonicalized.output.length == logicalNumOutput) val physical = df.queryExecution.executedPlan.collect { case d: DataSourceV2ScanExec => d }.head - assert(physical.canonicalized.output.length == numOutput) + assert(physical.canonicalized.output.length == physicalNumOutput) } val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() - checkCanonicalizedOutput(df, 2) - checkCanonicalizedOutput(df.select('i), 1) + checkCanonicalizedOutput(df, 2, 2) + checkCanonicalizedOutput(df.select('i), 2, 1) } } @@ -346,8 +347,8 @@ class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { - java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5)) + override def planInputPartitions(): JList[InputPartition[Row]] = { + java.util.Arrays.asList(new SimpleInputPartition(0, 5)) } } @@ -359,20 +360,21 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { - java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5), new SimpleDataReaderFactory(5, 10)) + override def planInputPartitions(): JList[InputPartition[Row]] = { + java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10)) } } override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class SimpleDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[Row] - with DataReader[Row] { +class SimpleInputPartition(start: Int, end: Int) + extends InputPartition[Row] + with InputPartitionReader[Row] { private var current = start - 1 - override def createDataReader(): DataReader[Row] = new SimpleDataReaderFactory(start, end) + override def createPartitionReader(): InputPartitionReader[Row] = + new SimpleInputPartition(start, end) override def next(): Boolean = { current += 1 @@ -413,21 +415,21 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { requiredSchema } - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { val lowerBound = filters.collect { case GreaterThan("i", v: Int) => v }.headOption - val res = new ArrayList[DataReaderFactory[Row]] + val res = new ArrayList[InputPartition[Row]] if (lowerBound.isEmpty) { - res.add(new AdvancedDataReaderFactory(0, 5, requiredSchema)) - res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) + res.add(new AdvancedInputPartition(0, 5, requiredSchema)) + res.add(new AdvancedInputPartition(5, 10, requiredSchema)) } else if (lowerBound.get < 4) { - res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 5, requiredSchema)) - res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) + res.add(new AdvancedInputPartition(lowerBound.get + 1, 5, requiredSchema)) + res.add(new AdvancedInputPartition(5, 10, requiredSchema)) } else if (lowerBound.get < 9) { - res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 10, requiredSchema)) + res.add(new AdvancedInputPartition(lowerBound.get + 1, 10, requiredSchema)) } res @@ -437,13 +439,13 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType) - extends DataReaderFactory[Row] with DataReader[Row] { +class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType) + extends InputPartition[Row] with InputPartitionReader[Row] { private var current = start - 1 - override def createDataReader(): DataReader[Row] = { - new AdvancedDataReaderFactory(start, end, requiredSchema) + override def createPartitionReader(): InputPartitionReader[Row] = { + new AdvancedInputPartition(start, end, requiredSchema) } override def close(): Unit = {} @@ -468,24 +470,24 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader with SupportsScanUnsafeRow { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = { - java.util.Arrays.asList(new UnsafeRowDataReaderFactory(0, 5), - new UnsafeRowDataReaderFactory(5, 10)) + override def planUnsafeInputPartitions(): JList[InputPartition[UnsafeRow]] = { + java.util.Arrays.asList(new UnsafeRowInputPartitionReader(0, 5), + new UnsafeRowInputPartitionReader(5, 10)) } } override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class UnsafeRowDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[UnsafeRow] with DataReader[UnsafeRow] { +class UnsafeRowInputPartitionReader(start: Int, end: Int) + extends InputPartition[UnsafeRow] with InputPartitionReader[UnsafeRow] { private val row = new UnsafeRow(2) row.pointTo(new Array[Byte](8 * 3), 8 * 3) private var current = start - 1 - override def createDataReader(): DataReader[UnsafeRow] = this + override def createPartitionReader(): InputPartitionReader[UnsafeRow] = this override def next(): Boolean = { current += 1 @@ -503,7 +505,7 @@ class UnsafeRowDataReaderFactory(start: Int, end: Int) class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { class Reader(val readSchema: StructType) extends DataSourceReader { - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = + override def planInputPartitions(): JList[InputPartition[Row]] = java.util.Collections.emptyList() } @@ -516,16 +518,17 @@ class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader with SupportsScanColumnarBatch { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = { - java.util.Arrays.asList(new BatchDataReaderFactory(0, 50), new BatchDataReaderFactory(50, 90)) + override def planBatchInputPartitions(): JList[InputPartition[ColumnarBatch]] = { + java.util.Arrays.asList( + new BatchInputPartitionReader(0, 50), new BatchInputPartitionReader(50, 90)) } } override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class BatchDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[ColumnarBatch] with DataReader[ColumnarBatch] { +class BatchInputPartitionReader(start: Int, end: Int) + extends InputPartition[ColumnarBatch] with InputPartitionReader[ColumnarBatch] { private final val BATCH_SIZE = 20 private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) @@ -534,7 +537,7 @@ class BatchDataReaderFactory(start: Int, end: Int) private var current = start - override def createDataReader(): DataReader[ColumnarBatch] = this + override def createPartitionReader(): InputPartitionReader[ColumnarBatch] = this override def next(): Boolean = { i.reset() @@ -568,11 +571,11 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader with SupportsReportPartitioning { override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { // Note that we don't have same value of column `a` across partitions. java.util.Arrays.asList( - new SpecificDataReaderFactory(Array(1, 1, 3), Array(4, 4, 6)), - new SpecificDataReaderFactory(Array(2, 4, 4), Array(6, 2, 2))) + new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)), + new SpecificInputPartitionReader(Array(2, 4, 4), Array(6, 2, 2))) } override def outputPartitioning(): Partitioning = new MyPartitioning @@ -590,14 +593,14 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class SpecificDataReaderFactory(i: Array[Int], j: Array[Int]) - extends DataReaderFactory[Row] - with DataReader[Row] { +class SpecificInputPartitionReader(i: Array[Int], j: Array[Int]) + extends InputPartition[Row] + with InputPartitionReader[Row] { assert(i.length == j.length) private var current = -1 - override def createDataReader(): DataReader[Row] = this + override def createPartitionReader(): InputPartitionReader[Row] = this override def next(): Boolean = { current += 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index a5007fa321359..1334cf71ae988 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -45,7 +45,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS class Reader(path: String, conf: Configuration) extends DataSourceReader { override def readSchema(): StructType = schema - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -54,9 +54,9 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS name.startsWith("_") || name.startsWith(".") }.map { f => val serializableConf = new SerializableConfiguration(conf) - new SimpleCSVDataReaderFactory( + new SimpleCSVInputPartitionReader( f.getPath.toUri.toString, - serializableConf): DataReaderFactory[Row] + serializableConf): InputPartition[Row] }.toList.asJava } else { Collections.emptyList() @@ -156,14 +156,14 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } -class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration) - extends DataReaderFactory[Row] with DataReader[Row] { +class SimpleCSVInputPartitionReader(path: String, conf: SerializableConfiguration) + extends InputPartition[Row] with InputPartitionReader[Row] { @transient private var lines: Iterator[String] = _ @transient private var currentLine: String = _ @transient private var inputStream: FSDataInputStream = _ - override def createDataReader(): DataReader[Row] = { + override def createPartitionReader(): InputPartitionReader[Row] = { val filePath = new Path(path) val fs = filePath.getFileSystem(conf.value) inputStream = fs.open(filePath) @@ -209,10 +209,10 @@ class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: Serializable override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[Row] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) - val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) new SimpleCSVDataWriter(fs, filePath) } @@ -245,10 +245,10 @@ class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: Seriali override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) - val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) new InternalRowCSVDataWriter(fs, filePath) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index d6bef9ce07379..58ed9790ea123 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -18,17 +18,22 @@ package org.apache.spark.sql.streaming import java.{util => ju} +import java.io.File import java.text.SimpleDateFormat import java.util.{Calendar, Date} +import org.apache.commons.io.FileUtils import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ +import org.apache.spark.util.Utils class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matchers with Logging { @@ -137,20 +142,12 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche assert(e.get("watermark") === formatTimestamp(5)) }, AddData(inputData2, 25), - CheckAnswer(), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(25)) - assert(e.get("min") === formatTimestamp(25)) - assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(5)) - }, - AddData(inputData2, 25), CheckAnswer((10, 3)), assertEventStats { e => assert(e.get("max") === formatTimestamp(25)) assert(e.get("min") === formatTimestamp(25)) assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(15)) + assert(e.get("watermark") === formatTimestamp(5)) } ) } @@ -167,15 +164,12 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch(), + CheckNewAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(), - assertNumStateRows(3), - AddData(inputData, 25), // Emit items less than watermark and drop their state - CheckLastBatch((10, 5)), + CheckNewAnswer((10, 5)), assertNumStateRows(2), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(2) ) } @@ -193,15 +187,15 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation, OutputMode.Update)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch((10, 5), (15, 1)), + CheckNewAnswer((10, 5), (15, 1)), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch((25, 1)), - assertNumStateRows(3), + CheckNewAnswer((25, 1)), + assertNumStateRows(2), AddData(inputData, 10, 25), // Ignore 10 as its less than watermark - CheckLastBatch((25, 2)), + CheckNewAnswer((25, 2)), assertNumStateRows(2), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(2) ) } @@ -251,56 +245,25 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(df)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch(), + CheckAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - StopStream, - StartStream(), - CheckLastBatch(), - AddData(inputData, 25), // Evict items less than previous watermark. - CheckLastBatch((10, 5)), + CheckAnswer((10, 5)), StopStream, AssertOnQuery { q => // purge commit and clear the sink - val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + 1L + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) q.commitLog.purge(commit) q.sink.asInstanceOf[MemorySink].clear() true }, StartStream(), - CheckLastBatch((10, 5)), // Recompute last batch and re-evict timestamp 10 - AddData(inputData, 30), // Advance watermark to 20 seconds - CheckLastBatch(), + AddData(inputData, 10, 27, 30), // Advance watermark to 20 seconds, 10 should be ignored + CheckAnswer((15, 1)), StopStream, - StartStream(), // Watermark should still be 15 seconds - AddData(inputData, 17), - CheckLastBatch(), // We still do not see next batch - AddData(inputData, 30), // Advance watermark to 20 seconds - CheckLastBatch(), - AddData(inputData, 30), // Evict items less than previous watermark. - CheckLastBatch((15, 2)) // Ensure we see next window - ) - } - - test("dropping old data") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) - - testStream(windowedAggregation)( - AddData(inputData, 10, 11, 12), - CheckAnswer(), - AddData(inputData, 25), // Advance watermark to 15 seconds - CheckAnswer(), - AddData(inputData, 25), // Evict items less than previous watermark. - CheckAnswer((10, 3)), - AddData(inputData, 10), // 10 is later than 15 second watermark - CheckAnswer((10, 3)), - AddData(inputData, 25), - CheckAnswer((10, 3)) // Should not emit an incorrect partial result. + StartStream(), + AddData(inputData, 17), // Watermark should still be 20 seconds, 17 should be ignored + CheckAnswer((15, 1)), + AddData(inputData, 40), // Advance watermark to 30 seconds, emit first data 25 + CheckNewAnswer((25, 2)) ) } @@ -421,8 +384,6 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche AddData(inputData, 10), CheckAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckAnswer(), - AddData(inputData, 25), // Evict items less than previous watermark. CheckAnswer((10, 1)) ) } @@ -501,8 +462,165 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche } } + test("test no-data flag") { + val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key + + def testWithFlag(flag: Boolean): Unit = withClue(s"with $flagKey = $flag") { + val inputData = MemoryStream[Int] + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + StartStream(additionalConfs = Map(flagKey -> flag.toString)), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckNewAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + // Check if there is new answer if flag is set, no new answer otherwise + if (flag) CheckNewAnswer((10, 5)) else CheckNewAnswer() + ) + } + + testWithFlag(true) + testWithFlag(false) + } + + test("MultipleWatermarkPolicy: max") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 15), // max(20 - 10, 30 - 15) = 15 + StopStream, + StartStream(), + checkWatermark(input1, 15), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input1, 115), // max(120 - 10, 130 - 15) = 115, policy recovered correctly + AddData(input1, 150), + CheckLastBatch(150), + checkWatermark(input1, 140) // should advance even if one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: min") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "min") { + testStream(dfWithMultipleWatermarks(input1, input2))( + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 10), // min(20 - 10, 30 - 15) = 10 + StopStream, + StartStream(), + checkWatermark(input1, 10), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input2, 110), // min(120 - 10, 130 - 15) = 110, policy recovered correctly + AddData(input2, 150), + CheckLastBatch(150), + checkWatermark(input2, 110) // does not advance when only one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: recovery from checkpoints ignores session conf") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val checkpointDir = Utils.createTempDir().getCanonicalFile + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 15) // max(20 - 10, 30 - 15) = 15 + ) + } + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "min") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + checkWatermark(input1, 15), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input1, 115), // max(120 - 10, 130 - 15) = 115, policy recovered correctly + AddData(input1, 150), + CheckLastBatch(150), + checkWatermark(input1, 140) // should advance even if one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: recovery from Spark ver 2.3.1 checkpoints ensures min policy") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + input1.addData(20) + input2.addData(30) + input1.addData(10) + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + Execute { _.processAllAvailable() }, + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input2, 110), // should calculate 'min' even if session conf has 'max' policy + AddData(input2, 150), + CheckLastBatch(150), + checkWatermark(input2, 110) + ) + } + } + + test("MultipleWatermarkPolicy: fail on incorrect conf values") { + val invalidValues = Seq("", "random") + invalidValues.foreach { value => + val e = intercept[IllegalArgumentException] { + spark.conf.set(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key, value) + } + assert(e.getMessage.toLowerCase.contains("valid values are 'min' and 'max'")) + } + } + + private def dfWithMultipleWatermarks( + input1: MemoryStream[Int], + input2: MemoryStream[Int]): Dataset[_] = { + val df1 = input1.toDF + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + val df2 = input2.toDF + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "15 seconds") + df1.union(df2).select($"eventTime".cast("int")) + } + + private def checkWatermark(input: MemoryStream[Int], watermark: Long) = Execute { q => + input.addData(1) + q.processAllAvailable() + assert(q.lastProgress.eventTime.get("watermark") == formatTimestamp(watermark)) + } + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get + q.processAllAvailable() + val progressWithData = q.recentProgress.lastOption.get assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows) true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index cf41d7e0e4fe1..ed53def556cb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -279,13 +279,10 @@ class FileStreamSinkSuite extends StreamTest { check() // nothing emitted yet addTimestamp(104, 123) // watermark = 90 before this, watermark = 123 - 10 = 113 after this - check() // nothing emitted yet + check((100L, 105L) -> 2L) // no-data-batch emits results on 100-105, addTimestamp(140) // wm = 113 before this, emit results on 100-105, wm = 130 after this - check((100L, 105L) -> 2L) - - addTimestamp(150) // wm = 130s before this, emit results on 120-125, wm = 150 after this - check((100L, 105L) -> 2L, (120L, 125L) -> 1L) + check((100L, 105L) -> 2L, (120L, 125L) -> 1L) // no-data-batch emits results on 120-125 } finally { if (query != null) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index b1416bff87ee7..988c8e6753e25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -615,20 +615,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), + CheckNewAnswer(("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")), + CheckNewAnswer(("a", "1"), ("c", "1")), assertNumStateRows(total = 3, updated = 2) ) } @@ -657,15 +657,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) testStream(result, Update)( AddData(inputData, "a", "a", "b"), - CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "1"), ("a", "2"), ("b", "1")), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), + CheckNewAnswer(("b", "2")), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")) + CheckNewAnswer(("a", "1"), ("c", "1")) ) } @@ -694,22 +694,22 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Complete)( AddData(inputData, "a"), - CheckLastBatch(("a", 1)), + CheckNewAnswer(("a", 1)), AddData(inputData, "a", "b"), // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 - CheckLastBatch(("a", 2), ("b", 1)), + CheckNewAnswer(("a", 2), ("b", 1)), StopStream, StartStream(), AddData(inputData, "a", "b"), // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; // so increment a and b by 1 - CheckLastBatch(("a", 3), ("b", 2)), + CheckNewAnswer(("a", 3), ("b", 2)), StopStream, StartStream(), AddData(inputData, "a", "c"), // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; // so increment a and c by 1 - CheckLastBatch(("a", 4), ("b", 2), ("c", 1)) + CheckNewAnswer(("a", 4), ("b", 2), ("c", 1)) ) } @@ -729,8 +729,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } test("flatMapGroupsWithState - streaming with processing time timeout") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + // Function to maintain the count as state and set the proc. time timeout delay of 10 seconds. + // It returns the count if changed, or -1 if the state was removed by timeout. val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } assertCannotGetWatermark { state.getCurrentWatermarkMs() } @@ -757,17 +757,17 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, "a"), AdvanceManualClock(1 * 1000), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "b"), AdvanceManualClock(1 * 1000), - CheckLastBatch(("b", "1")), + CheckNewAnswer(("b", "1")), assertNumStateRows(total = 2, updated = 1), AddData(inputData, "b"), AdvanceManualClock(10 * 1000), - CheckLastBatch(("a", "-1"), ("b", "2")), + CheckNewAnswer(("a", "-1"), ("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, @@ -775,38 +775,42 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest AddData(inputData, "c"), AdvanceManualClock(11 * 1000), - CheckLastBatch(("b", "-1"), ("c", "1")), + CheckNewAnswer(("b", "-1"), ("c", "1")), assertNumStateRows(total = 1, updated = 2), - AddData(inputData, "c"), - AdvanceManualClock(20 * 1000), - CheckLastBatch(("c", "2")), - assertNumStateRows(total = 1, updated = 1) + AdvanceManualClock(12 * 1000), + AssertOnQuery { _ => clock.getTimeMillis() == 35000 }, + Execute { q => + failAfter(streamingTimeout) { + while (q.lastProgress.timestamp != "1970-01-01T00:00:35.000Z") { + Thread.sleep(1) + } + } + }, + CheckNewAnswer(("c", "-1")), + assertNumStateRows(total = 0, updated = 0) ) } test("flatMapGroupsWithState - streaming with event time timeout + watermark") { - // Function to maintain the max event time - // Returns the max event time in the state, or -1 if the state was removed by timeout + // Function to maintain the max event time as state and set the timeout timestamp based on the + // current max event time seen. It returns the max event time in the state, or -1 if the state + // was removed by timeout. val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => { assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } - val timeoutDelay = 5 - if (key != "a") { - Iterator.empty + val timeoutDelaySec = 5 + if (state.hasTimedOut) { + state.remove() + Iterator((key, -1)) } else { - if (state.hasTimedOut) { - state.remove() - Iterator((key, -1)) - } else { - val valuesSeq = values.toSeq - val maxEventTime = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) - val timeoutTimestampMs = maxEventTime + timeoutDelay - state.update(maxEventTime) - state.setTimeoutTimestamp(timeoutTimestampMs * 1000) - Iterator((key, maxEventTime.toInt)) - } + val valuesSeq = values.toSeq + val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) + val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec + state.update(maxEventTimeSec) + state.setTimeoutTimestamp(timeoutTimestampSec * 1000) + Iterator((key, maxEventTimeSec.toInt)) } } val inputData = MemoryStream[(String, Int)] @@ -819,15 +823,23 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) testStream(result, Update)( - StartStream(Trigger.ProcessingTime("1 second")), - AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ... - CheckLastBatch(("a", 15)), // "a" to 15 + 5 = 20s, watermark to 5s + StartStream(), + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" - CheckLastBatch(), // No output as data should get filtered by watermark - AddData(inputData, ("dummy", 35)), // Set watermark = 35 - 10 = 25s - CheckLastBatch(), // No output as no data for "a" - AddData(inputData, ("a", 24)), // Add data older than watermark, should be ignored - CheckLastBatch(("a", -1)) // State for "a" should timeout and emit -1 + CheckNewAnswer(), // No output as data should get filtered by watermark + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. + CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 ) } @@ -856,20 +868,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 - CheckLastBatch(("a", "-1"), ("b", "2")), + CheckNewAnswer(("a", "-1"), ("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 - CheckLastBatch(("a", "1"), ("c", "1")), + CheckNewAnswer(("a", "1"), ("c", "1")), assertNumStateRows(total = 3, updated = 2) ) } @@ -920,15 +932,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( setFailInTask(false), AddData(inputData, "a"), - CheckLastBatch(("a", 1L)), + CheckNewAnswer(("a", 1L)), AddData(inputData, "a"), - CheckLastBatch(("a", 2L)), + CheckNewAnswer(("a", 2L)), setFailInTask(true), AddData(inputData, "a"), ExpectFailure[SparkException](), // task should fail but should not increment count setFailInTask(false), StartStream(), - CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count + CheckNewAnswer(("a", 3L)) // task should not fail, and should show correct count ) } @@ -938,7 +950,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc) testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch("a"), + CheckNewAnswer("a"), AssertOnQuery(_.lastExecution.executedPlan.outputPartitioning === UnknownPartitioning(0)) ) } @@ -1000,7 +1012,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, ("a", 1L)), AdvanceManualClock(1 * 1000), - CheckLastBatch(("a", "1")) + CheckNewAnswer(("a", "1")) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala index 368c4604dfca8..e45f9d3e2e97b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -17,20 +17,58 @@ package org.apache.spark.sql.streaming +import org.apache.spark.sql.execution.streaming.StreamExecution + trait StateStoreMetricsTest extends StreamTest { + private var lastCheckedRecentProgressIndex = -1 + private var lastQuery: StreamExecution = null + + override def beforeEach(): Unit = { + super.beforeEach() + lastCheckedRecentProgressIndex = -1 + } + def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery = AssertOnQuery(s"Check total state rows = $total, updated state rows = $updated") { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get - assert( - progressWithData.stateOperators.map(_.numRowsTotal) === total, - "incorrect total rows") - assert( - progressWithData.stateOperators.map(_.numRowsUpdated) === updated, - "incorrect updates rows") + val recentProgress = q.recentProgress + require(recentProgress.nonEmpty, "No progress made, cannot check num state rows") + require(recentProgress.length < spark.sessionState.conf.streamingProgressRetention, + "This test assumes that all progresses are present in q.recentProgress but " + + "some may have been dropped due to retention limits") + + if (q.ne(lastQuery)) lastCheckedRecentProgressIndex = -1 + lastQuery = q + + val numStateOperators = recentProgress.last.stateOperators.length + val progressesSinceLastCheck = recentProgress + .slice(lastCheckedRecentProgressIndex + 1, recentProgress.length) + .filter(_.stateOperators.length == numStateOperators) + + val allNumUpdatedRowsSinceLastCheck = + progressesSinceLastCheck.map(_.stateOperators.map(_.numRowsUpdated)) + + lazy val debugString = "recent progresses:\n" + + progressesSinceLastCheck.map(_.prettyJson).mkString("\n\n") + + val numTotalRows = recentProgress.last.stateOperators.map(_.numRowsTotal) + assert(numTotalRows === total, s"incorrect total rows, $debugString") + + val numUpdatedRows = arraySum(allNumUpdatedRowsSinceLastCheck, numStateOperators) + assert(numUpdatedRows === updated, s"incorrect updates rows, $debugString") + + lastCheckedRecentProgressIndex = recentProgress.length - 1 true } def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = assertNumStateRows(Seq(total), Seq(updated)) + + def arraySum(arraySeq: Seq[Array[Long]], arrayLength: Int): Seq[Long] = { + if (arraySeq.isEmpty) return Seq.fill(arrayLength)(0L) + + assert(arraySeq.forall(_.length == arrayLength), + "Arrays are of different lengths:\n" + arraySeq.map(_.toSeq).mkString("\n")) + (0 until arrayLength).map { index => arraySeq.map(_.apply(index)).sum } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index c1ec1eba69fb2..ca38f04136c7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -805,6 +805,114 @@ class StreamSuite extends StreamTest { } } + test("streaming limit without state") { + val inputData1 = MemoryStream[Int] + testStream(inputData1.toDF().limit(0))( + AddData(inputData1, 1 to 8: _*), + CheckAnswer()) + + val inputData2 = MemoryStream[Int] + testStream(inputData2.toDF().limit(4))( + AddData(inputData2, 1 to 8: _*), + CheckAnswer(1 to 4: _*)) + } + + test("streaming limit with state") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().limit(4))( + AddData(inputData, 1 to 2: _*), + CheckAnswer(1 to 2: _*), + AddData(inputData, 3 to 6: _*), + CheckAnswer(1 to 4: _*), + AddData(inputData, 7 to 9: _*), + CheckAnswer(1 to 4: _*)) + } + + test("streaming limit with other operators") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().where("value % 2 = 1").limit(4))( + AddData(inputData, 1 to 5: _*), + CheckAnswer(1, 3, 5), + AddData(inputData, 6 to 9: _*), + CheckAnswer(1, 3, 5, 7), + AddData(inputData, 10 to 12: _*), + CheckAnswer(1, 3, 5, 7)) + } + + test("streaming limit with multiple limits") { + val inputData1 = MemoryStream[Int] + testStream(inputData1.toDF().limit(4).limit(2))( + AddData(inputData1, 1), + CheckAnswer(1), + AddData(inputData1, 2 to 8: _*), + CheckAnswer(1, 2)) + + val inputData2 = MemoryStream[Int] + testStream(inputData2.toDF().limit(4).limit(100).limit(3))( + AddData(inputData2, 1, 2), + CheckAnswer(1, 2), + AddData(inputData2, 3 to 8: _*), + CheckAnswer(1 to 3: _*)) + } + + test("streaming limit in complete mode") { + val inputData = MemoryStream[Int] + val limited = inputData.toDF().limit(5).groupBy("value").count() + testStream(limited, OutputMode.Complete())( + AddData(inputData, 1 to 3: _*), + CheckAnswer(Row(1, 1), Row(2, 1), Row(3, 1)), + AddData(inputData, 1 to 9: _*), + CheckAnswer(Row(1, 2), Row(2, 2), Row(3, 2), Row(4, 1), Row(5, 1))) + } + + test("streaming limits in complete mode") { + val inputData = MemoryStream[Int] + val limited = inputData.toDF().limit(4).groupBy("value").count().orderBy("value").limit(3) + testStream(limited, OutputMode.Complete())( + AddData(inputData, 1 to 9: _*), + CheckAnswer(Row(1, 1), Row(2, 1), Row(3, 1)), + AddData(inputData, 2 to 6: _*), + CheckAnswer(Row(1, 1), Row(2, 2), Row(3, 2))) + } + + test("streaming limit in update mode") { + val inputData = MemoryStream[Int] + val e = intercept[AnalysisException] { + testStream(inputData.toDF().limit(5), OutputMode.Update())( + AddData(inputData, 1 to 3: _*) + ) + } + assert(e.getMessage.contains( + "Limits are not supported on streaming DataFrames/Datasets in Update output mode")) + } + + test("streaming limit in multiple partitions") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().repartition(2).limit(7))( + AddData(inputData, 1 to 10: _*), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 10)), + false), + AddData(inputData, 11 to 20: _*), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 10)), + false)) + } + + test("streaming limit in multiple partitions by column") { + val inputData = MemoryStream[(Int, Int)] + val df = inputData.toDF().repartition(2, $"_2").limit(7) + testStream(df)( + AddData(inputData, (1, 0), (2, 0), (3, 1), (4, 1)), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 4 && rows.forall(r => r.getInt(0) <= 4)), + false), + AddData(inputData, (5, 0), (6, 0), (7, 1), (8, 1)), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 8)), + false)) + } + for (e <- Seq( new InterruptedException, new InterruptedIOException, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index af0268fa47871..4c3fd58cb2e45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -192,14 +193,30 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class CheckAnswerRowsContains(expectedAnswer: Seq[Row], lastOnly: Boolean = false) extends StreamAction with StreamMustBeRunning { override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" - private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" + private def operatorName = if (lastOnly) "CheckLastBatchContains" else "CheckAnswerContains" } case class CheckAnswerRowsByFunc( globalCheckFunction: Seq[Row] => Unit, lastOnly: Boolean) extends StreamAction with StreamMustBeRunning { - override def toString: String = s"$operatorName" - private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" + override def toString: String = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" + } + + case class CheckNewAnswerRows(expectedAnswer: Seq[Row]) + extends StreamAction with StreamMustBeRunning { + override def toString: String = s"CheckNewAnswer: ${expectedAnswer.mkString(",")}" + } + + object CheckNewAnswer { + def apply(): CheckNewAnswerRows = CheckNewAnswerRows(Seq.empty) + + def apply[A: Encoder](data: A, moreData: A*): CheckNewAnswerRows = { + val encoder = encoderFor[A] + val toExternalRow = RowEncoder(encoder.schema).resolveAndBind() + CheckNewAnswerRows((data +: moreData).map(d => toExternalRow.fromRow(encoder.toRow(d)))) + } + + def apply(rows: Row*): CheckNewAnswerRows = CheckNewAnswerRows(rows) } /** Stops the stream. It must currently be running. */ @@ -275,7 +292,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Execute arbitrary code */ object Execute { def apply(func: StreamExecution => Any): AssertOnQuery = - AssertOnQuery(query => { func(query); true }) + AssertOnQuery(query => { func(query); true }, "Execute") } object AwaitEpoch { @@ -435,13 +452,24 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be """.stripMargin) } - def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { + var lastFetchedMemorySinkLastBatchId: Long = -1 + + def fetchStreamAnswer( + currentStream: StreamExecution, + lastOnly: Boolean = false, + sinceLastFetchOnly: Boolean = false) = { + verify( + !(lastOnly && sinceLastFetchOnly), "both lastOnly and sinceLastFetchOnly cannot be true") verify(currentStream != null, "stream not running") // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { currentStream.awaitOffset(sourceIndex, offset) + // Make sure all processing including no-data-batches have been executed + if (!currentStream.triggerClock.isInstanceOf[StreamManualClock]) { + currentStream.processAllAvailable() + } } } @@ -463,14 +491,21 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } - val (latestBatchData, allData) = sink match { - case s: MemorySink => (s.latestBatchData, s.allData) - case s: MemorySinkV2 => (s.latestBatchData, s.allData) - } - try if (lastOnly) latestBatchData else allData catch { + val rows = try { + if (sinceLastFetchOnly) { + if (sink.latestBatchId.getOrElse(-1L) < lastFetchedMemorySinkLastBatchId) { + failTest("MemorySink was probably cleared since last fetch. Use CheckAnswer instead.") + } + sink.dataSinceBatch(lastFetchedMemorySinkLastBatchId) + } else { + if (lastOnly) sink.latestBatchData else sink.allData + } + } catch { case e: Exception => failTest("Exception while getting data from sink", e) } + lastFetchedMemorySinkLastBatchId = sink.latestBatchId.getOrElse(-1L) + rows } def executeAction(action: StreamAction): Unit = { @@ -704,8 +739,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } catch { case e: Throwable => failTest(e.toString) } + + case CheckNewAnswerRows(expectedAnswer) => + val sparkAnswer = fetchStreamAnswer(currentStream, sinceLastFetchOnly = true) + QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } } - pos += 1 } try { @@ -719,8 +759,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be currentStream.asInstanceOf[MicroBatchExecution].withProgressLocked { actns.foreach(executeAction) } + pos += 1 - case action: StreamAction => executeAction(action) + case action: StreamAction => + executeAction(action) + pos += 1 } if (streamThreadDeathCause != null) { failTest("Stream Thread Died", streamThreadDeathCause) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 1cae8cb8d47f1..382da13430781 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -536,6 +536,31 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } + test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") { + // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error + // by ensuring the following. + // - A streaming query with a streaming aggregation. + // - Aggregation function 'collect_list' that is a subclass of TypedImperativeAggregate. + // - Post shuffle partition has exactly 128 records (i.e. the threshold at which + // ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a + // micro-batch with 128 records that shuffle to a single partition. + // This test throws the exact error reported in SPARK-23004 without the corresponding fix. + withSQLConf("spark.sql.shuffle.partitions" -> "1") { + val input = MemoryStream[Int] + val df = input.toDF().toDF("value") + .selectExpr("value as group", "value") + .groupBy("group") + .agg(collect_list("value")) + testStream(df, outputMode = OutputMode.Update)( + AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*), + AssertOnQuery { q => + q.processAllAvailable() + true + } + ) + } + } + /** Add blocks of data to the `BlockRDDBackedSource`. */ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala similarity index 80% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index 0088b64d6195e..42ffd472eb843 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec} import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf -class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class StreamingDeduplicationSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -97,28 +98,20 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { testStream(result, Append)( AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*), - CheckLastBatch(10 to 15: _*), + CheckAnswer(10 to 15: _*), assertNumStateRows(total = 6, updated = 6), - AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(25), - assertNumStateRows(total = 7, updated = 1), - - AddData(inputData, 25), // Drop states less than watermark - CheckLastBatch(), - assertNumStateRows(total = 1, updated = 0), + AddData(inputData, 25), // Advance watermark to 15 secs, no-data-batch drops rows <= 15 + CheckNewAnswer(25), + assertNumStateRows(total = 1, updated = 1), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(total = 1, updated = 0), - AddData(inputData, 45), // Advance watermark to 35 seconds - CheckLastBatch(45), - assertNumStateRows(total = 2, updated = 1), - - AddData(inputData, 45), // Drop states less than watermark - CheckLastBatch(), - assertNumStateRows(total = 1, updated = 0) + AddData(inputData, 45), // Advance watermark to 35 seconds, no-data-batch drops row 25 + CheckNewAnswer(45), + assertNumStateRows(total = 1, updated = 1) ) } @@ -141,33 +134,20 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { assertNumStateRows(total = Seq(2L, 6L), updated = Seq(2L, 6L)), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(), - // states in aggregate in [10, 14), [15, 20) and [25, 30) (3 windows) - // states in deduplicate is 10 to 15 and 25 - assertNumStateRows(total = Seq(3L, 7L), updated = Seq(1L, 1L)), - - AddData(inputData, 25), // Emit items less than watermark and drop their state - CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate - // states in aggregate in [15, 20) and [25, 30) (2 windows, note aggregate uses the end of - // window to evict items, so [15, 20) is still in the state store) - // states in deduplicate is 25 - assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), + CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate, emitted with no-data-batch + // states in aggregate in [15, 20) and [25, 30); no-data-batch removed [10, 14) + // states in deduplicate is 25, no-data-batch removed 10 to 14 + assertNumStateRows(total = Seq(2L, 1L), updated = Seq(1L, 1L)), AddData(inputData, 10), // Should not emit anything as data less than watermark CheckLastBatch(), assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), AddData(inputData, 40), // Advance watermark to 30 seconds - CheckLastBatch(), - // states in aggregate in [15, 20), [25, 30) and [40, 45) - // states in deduplicate is 25 and 40, - assertNumStateRows(total = Seq(3L, 2L), updated = Seq(1L, 1L)), - - AddData(inputData, 40), // Emit items less than watermark and drop their state CheckLastBatch((15 -> 1), (25 -> 1)), - // states in aggregate in [40, 45) - // states in deduplicate is 40, - assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)) + // states in aggregate is [40, 45); no-data-batch removed [15, 20) and [25, 30) + // states in deduplicate is 40; no-data-batch removed 25 + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)) ) } @@ -260,13 +240,13 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { .select($"id") testStream(df)( AddData(input, 1 -> 1, 1 -> 1, 1 -> 2), - CheckLastBatch(1, 2), + CheckAnswer(1, 2), AddData(input, 1 -> 1, 2 -> 3, 2 -> 4), - CheckLastBatch(3, 4), + CheckNewAnswer(3, 4), AddData(input, 1 -> 0, 1 -> 1, 3 -> 5, 3 -> 6), // Drop (1 -> 0, 1 -> 1) due to watermark - CheckLastBatch(5, 6), + CheckNewAnswer(5, 6), AddData(input, 1 -> 0, 4 -> 7), // Drop (1 -> 0) due to watermark - CheckLastBatch(7) + CheckNewAnswer(7) ) } @@ -279,7 +259,37 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { .select($"id", $"time".cast("long")) testStream(df)( AddData(input, 1 -> 1, 1 -> 2, 2 -> 2), - CheckLastBatch(1 -> 1, 2 -> 2) + CheckAnswer(1 -> 1, 2 -> 2) ) } + + test("test no-data flag") { + val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key + + def testWithFlag(flag: Boolean): Unit = withClue(s"with $flagKey = $flag") { + val inputData = MemoryStream[Int] + val result = inputData.toDS() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates() + .select($"eventTime".cast("long").as[Long]) + + testStream(result, Append)( + StartStream(additionalConfs = Map(flagKey -> flag.toString)), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckAnswer(10, 11, 12, 13, 14, 15), + assertNumStateRows(total = 6, updated = 6), + + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckNewAnswer(25), + { // State should have been cleaned if flag is set, otherwise should not have been cleaned + if (flag) assertNumStateRows(total = 1, updated = 1) + else assertNumStateRows(total = 7, updated = 1) + } + ) + } + + testWithFlag(true) + testWithFlag(false) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 11bdd13942dcb..c5cc8df4356a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -62,20 +62,20 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input1, 1), CheckAnswer(), AddData(input2, 1, 10), // 1 arrived on input1 first, then input2, should join - CheckLastBatch((1, 2, 3)), + CheckNewAnswer((1, 2, 3)), AddData(input1, 10), // 10 arrived on input2 first, then input1, should join - CheckLastBatch((10, 20, 30)), + CheckNewAnswer((10, 20, 30)), AddData(input2, 1), // another 1 in input2 should join with 1 input1 - CheckLastBatch((1, 2, 3)), + CheckNewAnswer((1, 2, 3)), StopStream, StartStream(), AddData(input1, 1), // multiple 1s should be kept in state causing multiple (1, 2, 3) - CheckLastBatch((1, 2, 3), (1, 2, 3)), + CheckNewAnswer((1, 2, 3), (1, 2, 3)), StopStream, StartStream(), AddData(input1, 100), AddData(input2, 100), - CheckLastBatch((100, 200, 300)) + CheckNewAnswer((100, 200, 300)) ) } @@ -97,25 +97,25 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( AddData(input1, 1), - CheckLastBatch(), + CheckNewAnswer(), AddData(input2, 1), - CheckLastBatch((1, 10, 2, 3)), + CheckNewAnswer((1, 10, 2, 3)), StopStream, StartStream(), AddData(input1, 25), - CheckLastBatch(), + CheckNewAnswer(), StopStream, StartStream(), AddData(input2, 25), - CheckLastBatch((25, 30, 50, 75)), + CheckNewAnswer((25, 30, 50, 75)), AddData(input1, 1), - CheckLastBatch((1, 10, 2, 3)), // State for 1 still around as there is no watermark + CheckNewAnswer((1, 10, 2, 3)), // State for 1 still around as there is no watermark StopStream, StartStream(), AddData(input1, 5), - CheckLastBatch(), + CheckNewAnswer(), AddData(input2, 5), - CheckLastBatch((5, 10, 10, 15)) // No filter by any watermark + CheckNewAnswer((5, 10, 10, 15)) // No filter by any watermark ) } @@ -142,27 +142,27 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with assertNumStateRows(total = 1, updated = 1), AddData(input2, 1), - CheckLastBatch((1, 10, 2, 3)), + CheckAnswer((1, 10, 2, 3)), assertNumStateRows(total = 2, updated = 1), StopStream, StartStream(), AddData(input1, 25), - CheckLastBatch(), // since there is only 1 watermark operator, the watermark should be 15 - assertNumStateRows(total = 3, updated = 1), + CheckNewAnswer(), // watermark = 15, no-data-batch should remove 2 rows having window=[0,10] + assertNumStateRows(total = 1, updated = 1), AddData(input2, 25), - CheckLastBatch((25, 30, 50, 75)), // watermark = 15 should remove 2 rows having window=[0,10] + CheckNewAnswer((25, 30, 50, 75)), assertNumStateRows(total = 2, updated = 1), StopStream, StartStream(), AddData(input2, 1), - CheckLastBatch(), // Should not join as < 15 removed - assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 + CheckNewAnswer(), // Should not join as < 15 removed + assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 AddData(input1, 5), - CheckLastBatch(), // Should not join or add to state as < 15 got filtered by watermark + CheckNewAnswer(), // Same reason as above assertNumStateRows(total = 2, updated = 0) ) } @@ -189,42 +189,39 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 5)), CheckAnswer(), AddData(rightInput, (1, 11)), - CheckLastBatch((1, 5, 11)), + CheckNewAnswer((1, 5, 11)), AddData(rightInput, (1, 10)), - CheckLastBatch(), // no match as neither 5, nor 10 from leftTime is less than rightTime 10 - 5 - assertNumStateRows(total = 3, updated = 1), + CheckNewAnswer(), // no match as leftTime 5 is not < rightTime 10 - 5 + assertNumStateRows(total = 3, updated = 3), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 3), (1, 30)), - CheckLastBatch((1, 3, 10), (1, 3, 11)), + CheckNewAnswer((1, 3, 10), (1, 3, 11)), assertNumStateRows(total = 5, updated = 2), AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 6, updated = 1), + CheckNewAnswer(), // event time watermark: max event time - 10 ==> 30 - 10 = 20 + // so left side going to only receive data where leftTime > 20 // right side state constraint: 20 < leftTime < rightTime - 5 ==> rightTime > 25 - - // Run another batch with event time = 25 to clear right state where rightTime <= 25 - AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 5, updated = 1), // removed (1, 11) and (1, 10), added (0, 30) + // right state where rightTime <= 25 will be cleared, (1, 11) and (1, 10) removed + assertNumStateRows(total = 4, updated = 1), // New data to right input should match with left side (1, 3) and (1, 5), as left state should // not be cleared. But rows rightTime <= 20 should be filtered due to event time watermark and // state rows with rightTime <= 25 should be removed from state. // (1, 20) ==> filtered by event time watermark = 20 // (1, 21) ==> passed filter, matched with left (1, 3) and (1, 5), not added to state - // as state watermark = 25 + // as 21 < state watermark = 25 // (1, 28) ==> passed filter, matched with left (1, 3) and (1, 5), added to state AddData(rightInput, (1, 20), (1, 21), (1, 28)), - CheckLastBatch((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), - assertNumStateRows(total = 6, updated = 1), + CheckNewAnswer((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), + assertNumStateRows(total = 5, updated = 1), // New data to left input with leftTime <= 20 should be filtered due to event time watermark AddData(leftInput, (1, 20), (1, 21)), - CheckLastBatch((1, 21, 28)), - assertNumStateRows(total = 7, updated = 1) + CheckNewAnswer((1, 21, 28)), + assertNumStateRows(total = 6, updated = 1) ) } @@ -275,38 +272,39 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 20)), CheckAnswer(), AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), - CheckLastBatch((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), - assertNumStateRows(total = 7, updated = 6), + CheckNewAnswer((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), + assertNumStateRows(total = 7, updated = 7), // If rightTime = 60, then it matches only leftTime = [50, 65] AddData(rightInput, (1, 60)), - CheckLastBatch(), // matches with nothing on the left + CheckNewAnswer(), // matches with nothing on the left AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), - CheckLastBatch((1, 50, 60), (1, 65, 60)), - assertNumStateRows(total = 12, updated = 4), + CheckNewAnswer((1, 50, 60), (1, 65, 60)), // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) // Should drop < 20 from left, i.e., none // Right state value watermark = 30 - 5 = slightly less than 25 (since condition has <=) // Should drop < 25 from the right, i.e., 14 and 15 - AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to stat - CheckLastBatch((1, 31, 26), (1, 31, 30), (1, 31, 31)), - assertNumStateRows(total = 11, updated = 1), // 12 - 2 removed + 1 added + assertNumStateRows(total = 10, updated = 5), // 12 - 2 removed + + AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to state + CheckNewAnswer((1, 31, 26), (1, 31, 30), (1, 31, 31)), + assertNumStateRows(total = 11, updated = 1), // only 31 added // Advance the watermark AddData(rightInput, (1, 80)), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 1), - + CheckNewAnswer(), // Event time watermark = min(left: 66 - delay 20 = 46, right: 80 - delay 30 = 50) = 46 // Left state value watermark = 46 - 10 = slightly less than 36 (since condition has <=) // Should drop < 36 from left, i.e., 20, 31 (30 was not added) // Right state value watermark = 46 - 5 = slightly less than 41 (since condition has <=) // Should drop < 41 from the right, i.e., 25, 26, 30, 31 - AddData(rightInput, (1, 50)), - CheckLastBatch((1, 49, 50), (1, 50, 50)), - assertNumStateRows(total = 7, updated = 1) // 12 - 6 removed + 1 added + assertNumStateRows(total = 6, updated = 1), // 12 - 6 removed + + AddData(rightInput, (1, 46), (1, 50)), // 46 should not be processed or added to state + CheckNewAnswer((1, 49, 50), (1, 50, 50)), + assertNumStateRows(total = 7, updated = 1) // 50 added ) } @@ -322,7 +320,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with input1.addData(1) q.awaitTermination(10000) } - assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) + assert(e.toString.contains("Stream-stream join without equality predicate is not supported")) } test("stream stream self join") { @@ -404,10 +402,25 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input1, 1, 5), AddData(input2, 1, 5, 10), AddData(input3, 5, 10), - CheckLastBatch((5, 10, 5, 15, 5, 25))) + CheckNewAnswer((5, 10, 5, 15, 5, 25))) + } + + test("streaming join should require HashClusteredDistribution from children") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF.select('value as 'a, 'value * 2 as 'b) + val df2 = input2.toDF.select('value as 'a, 'value * 2 as 'b).repartition('b) + val joined = df1.join(df2, Seq("a", "b")).select('a) + + testStream(joined)( + AddData(input1, 1.to(1000): _*), + AddData(input2, 1.to(1000): _*), + CheckAnswer(1.to(1000): _*)) } } + class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { import testImplicits._ @@ -465,13 +478,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with leftValue <= 4 should generate their outer join row now and // not get added to the state. - CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), + CheckNewAnswer(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + CheckNewAnswer(), AddData(rightInput, 20), - CheckLastBatch((20, 30, 40, "60")) + CheckNewAnswer((20, 30, 40, "60")) ) } @@ -492,15 +505,15 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), - // The right rows with value <= 7 should never be added to the state. - CheckLastBatch(Row(3, 10, 6, "9")), + // The right rows with rightValue <= 7 should never be added to the state. + CheckNewAnswer(Row(3, 10, 6, "9")), // rightValue = 9 > 7 hence joined and added to state assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch computes nulls + CheckNewAnswer(Row(4, 10, 8, null), Row(5, 10, 10, null)), AddData(rightInput, 20), - CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null)) + CheckNewAnswer(Row(20, 30, 40, "60")) ) } @@ -521,15 +534,15 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), - // The left rows with value <= 4 should never be added to the state. - CheckLastBatch(Row(3, 10, 6, "9")), + // The left rows with leftValue <= 4 should never be added to the state. + CheckNewAnswer(Row(3, 10, 6, "9")), // leftValue = 7 > 4 hence joined and added to state assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch computes nulls + CheckNewAnswer(Row(4, 10, null, "12"), Row(5, 10, null, "15")), AddData(rightInput, 20), - CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15")) + CheckNewAnswer(Row(20, 30, 40, "60")) ) } @@ -552,13 +565,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with rightValue <= 7 should generate their outer join row now and // not get added to the state. - CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), + CheckNewAnswer(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + CheckNewAnswer(), AddData(rightInput, 20), - CheckLastBatch((20, 30, 40, "60")) + CheckNewAnswer((20, 30, 40, "60")) ) } @@ -568,14 +581,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), - CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), - // Old state doesn't get dropped until the batch *after* it gets introduced, so the - // nulls won't show up until the next batch after the watermark advances. - MultiAddData(leftInput, 21)(rightInput, 22), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 2), + CheckNewAnswer((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + + MultiAddData(leftInput, 21)(rightInput, 22), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer(Row(1, 10, 2, null), Row(2, 10, 4, null)), + assertNumStateRows(total = 2, updated = 12), + AddData(leftInput, 22), - CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), + CheckNewAnswer(Row(22, 30, 44, 66)), assertNumStateRows(total = 3, updated = 1) ) } @@ -586,14 +599,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), - CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), - // Old state doesn't get dropped until the batch *after* it gets introduced, so the - // nulls won't show up until the next batch after the watermark advances. - MultiAddData(leftInput, 21)(rightInput, 22), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 2), + CheckNewAnswer((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + + MultiAddData(leftInput, 21)(rightInput, 22), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer(Row(6, 10, null, 18), Row(7, 10, null, 21)), + assertNumStateRows(total = 2, updated = 12), + AddData(leftInput, 22), - CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), + CheckNewAnswer(Row(22, 30, 44, 66)), assertNumStateRows(total = 3, updated = 1) ) } @@ -627,21 +640,18 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 5), (3, 5)), CheckAnswer(), AddData(rightInput, (1, 10), (2, 5)), - CheckLastBatch((1, 1, 5, 10)), + CheckNewAnswer((1, 1, 5, 10)), AddData(rightInput, (1, 11)), - CheckLastBatch(), // no match as left time is too low - assertNumStateRows(total = 5, updated = 1), + CheckNewAnswer(), // no match as left time is too low + assertNumStateRows(total = 5, updated = 5), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 7), (1, 30)), - CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)), + CheckNewAnswer((1, 1, 7, 10), (1, 1, 7, 11)), assertNumStateRows(total = 7, updated = 2), - AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 8, updated = 1), - AddData(rightInput, (0, 30)), - CheckLastBatch(outerResult), - assertNumStateRows(total = 3, updated = 1) + AddData(rightInput, (0, 30)), // watermark = 30 - 10 = 20, no-data-batch computes nulls + CheckNewAnswer(outerResult), + assertNumStateRows(total = 2, updated = 1) ) } } @@ -665,36 +675,41 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // leftValue <= 10 should generate outer join rows even though it matches right keys MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3), - CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), - assertNumStateRows(total = 5, updated = 2), + CheckNewAnswer(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), + assertNumStateRows(total = 3, updated = 3), // only right 1, 2, 3 added + + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch cleared < 10 + CheckNewAnswer(), + assertNumStateRows(total = 2, updated = 2), // only 20 and 21 left in state + AddData(rightInput, 20), - CheckLastBatch( - Row(20, 30, 40, 60)), + CheckNewAnswer(Row(20, 30, 40, 60)), assertNumStateRows(total = 3, updated = 1), + // leftValue and rightValue both satisfying condition should not generate outer join rows - MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), - CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), - MultiAddData(leftInput, 70)(rightInput, 71), - CheckLastBatch(), - assertNumStateRows(total = 6, updated = 2), + MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), // watermark = 31 + CheckNewAnswer((40, 50, 80, 120), (41, 50, 82, 123)), + assertNumStateRows(total = 4, updated = 4), // only left 40, 41 + right 40,41 left in state + + MultiAddData(leftInput, 70)(rightInput, 71), // watermark = 60 + CheckNewAnswer(), + assertNumStateRows(total = 2, updated = 2), // only 70, 71 left in state + AddData(rightInput, 70), - CheckLastBatch((70, 80, 140, 210)), + CheckNewAnswer((70, 80, 140, 210)), assertNumStateRows(total = 3, updated = 1), + // rightValue between 300 and 1000 should generate outer join rows even though it matches left - MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), - CheckLastBatch(), + MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), // watermark = 91 + CheckNewAnswer(), + assertNumStateRows(total = 6, updated = 3), // only 101 - 103 left in state + MultiAddData(leftInput, 1000)(rightInput, 1001), - CheckLastBatch(), - assertNumStateRows(total = 8, updated = 2), - AddData(rightInput, 1000), - CheckLastBatch( - Row(1000, 1010, 2000, 3000), + CheckNewAnswer( Row(101, 110, 202, null), Row(102, 110, 204, null), Row(103, 110, 206, null)), - assertNumStateRows(total = 3, updated = 1) + assertNumStateRows(total = 2, updated = 2) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala new file mode 100644 index 0000000000000..1aaf8a9aa2d55 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import scala.language.reflectiveCalls + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkConf +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.streaming.StreamingQueryListener._ + + +class StreamingQueryListenersConfSuite extends StreamTest with BeforeAndAfter { + + import testImplicits._ + + + override protected def sparkConf: SparkConf = + super.sparkConf.set("spark.sql.streaming.streamingQueryListeners", + "org.apache.spark.sql.streaming.TestListener") + + test("test if the configured query lister is loaded") { + testStream(MemoryStream[Int].toDS)( + StartStream(), + StopStream + ) + + assert(TestListener.queryStartedEvent != null) + assert(TestListener.queryTerminatedEvent != null) + } + +} + +object TestListener { + @volatile var queryStartedEvent: QueryStartedEvent = null + @volatile var queryTerminatedEvent: QueryTerminatedEvent = null +} + +class TestListener(sparkConf: SparkConf) extends StreamingQueryListener { + + override def onQueryStarted(event: QueryStartedEvent): Unit = { + TestListener.queryStartedEvent = event + } + + override def onQueryProgress(event: QueryProgressEvent): Unit = {} + + override def onQueryTerminated(event: QueryTerminatedEvent): Unit = { + TestListener.queryTerminatedEvent = event + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 20942ed93897c..936a076d647b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.InputPartition import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType @@ -227,10 +227,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } // getBatch should take 100 ms the first time it is called - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { synchronized { clock.waitTillTime(1350) - super.createUnsafeRowReaderFactories() + super.planUnsafeInputPartitions() } } } @@ -290,13 +290,14 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AdvanceManualClock(100), // time = 1150 to unblock getEndOffset AssertClockTime(1150), - AssertStreamExecThreadIsWaitingForTime(1350), // will block on createReadTasks that needs 1350 + // will block on planInputPartitions that needs 1350 + AssertStreamExecThreadIsWaitingForTime(1350), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(200), // time = 1350 to unblock createReadTasks + AdvanceManualClock(200), // time = 1350 to unblock planInputPartitions AssertClockTime(1350), AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), @@ -334,8 +335,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources.length === 1) assert(progress.sources(0).description contains "MemoryStream") - assert(progress.sources(0).startOffset === null) - assert(progress.sources(0).endOffset !== null) + assert(progress.sources(0).startOffset === null) // no prior offset + assert(progress.sources(0).endOffset === "0") assert(progress.sources(0).processedRowsPerSecond === 4.0) // 2 rows processed in 500 ms assert(progress.stateOperators.length === 1) @@ -361,6 +362,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(query.lastProgress.batchId === 1) assert(query.lastProgress.inputRowsPerSecond === 2.0) assert(query.lastProgress.sources(0).inputRowsPerSecond === 2.0) + assert(query.lastProgress.sources(0).startOffset === "0") + assert(query.lastProgress.sources(0).endOffset === "1") true }, @@ -466,7 +469,17 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } - test("input row calculation with mixed batch and streaming sources") { + test("input row calculation with same V1 source used twice in self-join") { + val streamingTriggerDF = spark.createDataset(1 to 10).toDF + val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") + + val progress = getFirstProgress(streamingInputDF.join(streamingInputDF, "value")) + assert(progress.numInputRows === 20) // data is read multiple times in self-joins + assert(progress.sources.size === 1) + assert(progress.sources(0).numInputRows === 20) + } + + test("input row calculation with mixed batch and streaming V1 sources") { val streamingTriggerDF = spark.createDataset(1 to 10).toDF val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") @@ -479,7 +492,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources(0).numInputRows === 10) } - test("input row calculation with trigger input DF having multiple leaves") { + test("input row calculation with trigger input DF having multiple leaves in V1 source") { val streamingTriggerDF = spark.createDataset(1 to 5).toDF.union(spark.createDataset(6 to 10).toDF) require(streamingTriggerDF.logicalPlan.collectLeaves().size > 1) @@ -492,6 +505,121 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources(0).numInputRows === 10) } + test("input row calculation with same V2 source used twice in self-union") { + val streamInput = MemoryStream[Int] + + testStream(streamInput.toDF().union(streamInput.toDF()), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 1, 2, 2, 3, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 6) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 6) + true + } + ) + } + + test("input row calculation with same V2 source used twice in self-join") { + val streamInput = MemoryStream[Int] + val df = streamInput.toDF() + testStream(df.join(df, "value"), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 6) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 6) + true + } + ) + } + + test("input row calculation with trigger having data for only one of two V2 sources") { + val streamInput1 = MemoryStream[Int] + val streamInput2 = MemoryStream[Int] + + testStream(streamInput1.toDF().union(streamInput2.toDF()), useV2Sink = true)( + AddData(streamInput1, 1, 2, 3), + CheckLastBatch(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 2) + assert(lastProgress.get.sources(0).numInputRows == 3) + assert(lastProgress.get.sources(1).numInputRows == 0) + true + }, + AddData(streamInput2, 4, 5), + CheckLastBatch(4, 5), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 2) + assert(lastProgress.get.sources.length == 2) + assert(lastProgress.get.sources(0).numInputRows == 0) + assert(lastProgress.get.sources(1).numInputRows == 2) + true + } + ) + } + + test("input row calculation with mixed batch and streaming V2 sources") { + + val streamInput = MemoryStream[Int] + val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") + + testStream(streamInput.toDF().join(staticInputDF, "value"), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + AssertOnQuery { q => + q.processAllAvailable() + + // The number of leaves in the trigger's logical plan should be same as the executed plan. + require( + q.lastExecution.logical.collectLeaves().length == + q.lastExecution.executedPlan.collectLeaves().length) + + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 3) + true + } + ) + + val streamInput2 = MemoryStream[Int] + val staticInputDF2 = staticInputDF.union(staticInputDF).cache() + + testStream(streamInput2.toDF().join(staticInputDF2, "value"), useV2Sink = true)( + AddData(streamInput2, 1, 2, 3), + AssertOnQuery { q => + q.processAllAvailable() + // The number of leaves in the trigger's logical plan should be different from + // the executed plan. The static input will have two leaves in the logical plan + // (due to the union), but will be converted to a single leaf in the executed plan + // (due to the caching, the cached subplan is replaced by a single InMemoryTableScanExec). + require( + q.lastExecution.logical.collectLeaves().length != + q.lastExecution.executedPlan.collectLeaves().length) + + // Despite the mismatch in total number of leaves in the logical and executed plans, + // we should be able to attribute streaming input metrics to the streaming sources. + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 3) + true + } + ) + } + testQuietly("StreamExecution metadata garbage collection") { val inputData = MemoryStream[Int] val mapped = inputData.toDS().map(6 / _) @@ -706,6 +834,21 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi CheckLastBatch(("A", 1))) } + test("StreamingRelationV2/StreamingExecutionRelation/ContinuousExecutionRelation.toJSON " + + "should not fail") { + val df = spark.readStream.format("rate").load() + assert(df.logicalPlan.toJSON.contains("StreamingRelationV2")) + + testStream(df)( + AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingExecutionRelation")) + ) + + testStream(df, useV2Sink = true)( + StartStream(trigger = Trigger.Continuous(100)), + AssertOnQuery(_.logicalPlan.toJSON.contains("ContinuousExecutionRelation")) + ) + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) @@ -733,6 +876,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + /** Returns the last query progress from query.recentProgress where numInputRows is positive */ + def getLastProgressWithData(q: StreamingQuery): Option[StreamingQueryProgress] = { + q.recentProgress.filter(_.numInputRows > 0).lastOption + } + /** * A [[StreamAction]] to test the behavior of `StreamingQuery.awaitTermination()`. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala new file mode 100644 index 0000000000000..0223812600961 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.OutputMode + +class ContinuousAggregationSuite extends ContinuousSuiteBase { + import testImplicits._ + + test("not enabled") { + val ex = intercept[AnalysisException] { + val input = ContinuousMemoryStream.singlePartition[Int] + testStream(input.toDF().agg(max('value)), OutputMode.Complete)() + } + + assert(ex.getMessage.contains( + "In continuous processing mode, coalesce(1) must be called before aggregate operation")) + } + + test("basic") { + withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { + val input = ContinuousMemoryStream.singlePartition[Int] + + testStream(input.toDF().agg(max('value)), OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + } + + test("multiple partitions with coalesce") { + val input = ContinuousMemoryStream[Int] + + val df = input.toDF().coalesce(1).agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + + test("multiple partitions with coalesce - multiple transformations") { + val input = ContinuousMemoryStream[Int] + + // We use a barrier to make sure predicates both before and after coalesce work + val df = input.toDF() + .select('value as 'copy, 'value) + .where('copy =!= 1) + .planWithBarrier + .coalesce(1) + .where('copy =!= 2) + .agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(0), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + + test("multiple partitions with multiple coalesce") { + val input = ContinuousMemoryStream[Int] + + val df = input.toDF() + .coalesce(1) + .planWithBarrier + .coalesce(1) + .select('value as 'copy, 'value) + .agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + + test("repeated restart") { + withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { + val input = ContinuousMemoryStream.singlePartition[Int] + + testStream(input.toDF().agg(max('value)), OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + StartStream(), + StopStream, + StartStream(), + StopStream, + StartStream(), + AddData(input, 0), + CheckAnswer(2), + AddData(input, 5), + CheckAnswer(5)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala new file mode 100644 index 0000000000000..0e7e6febb53df --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} + +import org.mockito.{ArgumentCaptor, Matchers} +import org.mockito.Mockito._ +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType} + +class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { + case class LongPartitionOffset(offset: Long) extends PartitionOffset + + val coordinatorId = s"${getClass.getSimpleName}-epochCoordinatorIdForUnitTest" + val startEpoch = 0 + + var epochEndpoint: RpcEndpointRef = _ + + override def beforeEach(): Unit = { + super.beforeEach() + epochEndpoint = EpochCoordinatorRef.create( + mock[StreamWriter], + mock[ContinuousReader], + mock[ContinuousExecution], + coordinatorId, + startEpoch, + spark, + SparkEnv.get) + EpochTracker.initializeCurrentEpoch(0) + } + + override def afterEach(): Unit = { + SparkEnv.get.rpcEnv.stop(epochEndpoint) + epochEndpoint = null + super.afterEach() + } + + + private val mockContext = mock[TaskContext] + when(mockContext.getLocalProperty(ContinuousExecution.START_EPOCH_KEY)) + .thenReturn(startEpoch.toString) + when(mockContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)) + .thenReturn(coordinatorId) + + /** + * Set up a ContinuousQueuedDataReader for testing. The blocking queue can be used to send + * rows to the wrapped data reader. + */ + private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { + val queue = new ArrayBlockingQueue[UnsafeRow](1024) + val factory = new InputPartition[UnsafeRow] { + override def createPartitionReader() = new ContinuousInputPartitionReader[UnsafeRow] { + var index = -1 + var curr: UnsafeRow = _ + + override def next() = { + curr = queue.take() + index += 1 + true + } + + override def get = curr + + override def getOffset = LongPartitionOffset(index) + + override def close() = {} + } + } + val reader = new ContinuousQueuedDataReader( + new ContinuousDataSourceRDDPartition(0, factory), + mockContext, + dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, + epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) + + (queue, reader) + } + + private def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + test("basic data read") { + val (input, reader) = setup() + + input.add(unsafeRow(12345)) + assert(reader.next().getInt(0) == 12345) + } + + test("basic epoch marker") { + val (input, reader) = setup() + + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + } + + test("new rows after markers") { + val (input, reader) = setup() + + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + assert(reader.next() == null) + assert(reader.next() == null) + input.add(unsafeRow(11111)) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 11111) + assert(reader.next().getInt(0) == 22222) + } + + test("new markers after rows") { + val (input, reader) = setup() + + input.add(unsafeRow(11111)) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 11111) + assert(reader.next().getInt(0) == 22222) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + assert(reader.next() == null) + assert(reader.next() == null) + } + + test("alternating markers and rows") { + val (input, reader) = setup() + + input.add(unsafeRow(11111)) + assert(reader.next().getInt(0) == 11111) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 22222) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + input.add(unsafeRow(33333)) + assert(reader.next().getInt(0) == 33333) + input.add(unsafeRow(44444)) + assert(reader.next().getInt(0) == 44444) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index c318b951ff992..4980b0cd41f81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -41,7 +41,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, _, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 @@ -75,73 +75,50 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("map") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .map(r => r.getLong(0) * 2) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().map(_.getInt(0) * 2) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 40, 2).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer(0, 2), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(0, 2, 4, 6, 8)) } test("flatMap") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .flatMap(r => Seq(0, r.getLong(0), r.getLong(0) * 2)) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().flatMap(r => Seq(0, r.getInt(0), r.getInt(0) * 2)) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).flatMap(n => Seq(0, n, n * 2)).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer((0 to 1).flatMap(n => Seq(0, n, n * 2)): _*), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer((0 to 4).flatMap(n => Seq(0, n, n * 2)): _*)) } test("filter") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .where('value > 5) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().where('value > 2) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer(), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(3, 4)) } test("deduplicate") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .dropDuplicates() + val input = ContinuousMemoryStream[Int] + val df = input.toDF().dropDuplicates() val except = intercept[AnalysisException] { - testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + testStream(df)(StartStream()) } assert(except.message.contains( @@ -149,15 +126,11 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("timestamp") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select(current_timestamp()) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().select(current_timestamp()) val except = intercept[AnalysisException] { - testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + testStream(df)(StartStream()) } assert(except.message.contains( @@ -165,58 +138,43 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("subquery alias") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .createOrReplaceTempView("rate") - val test = spark.sql("select value from rate where value > 5") + val input = ContinuousMemoryStream[Int] + input.toDF().createOrReplaceTempView("memory") + val test = spark.sql("select value from memory where value > 2") - testStream(test, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + testStream(test)( + AddData(input, 0, 1), + CheckAnswer(), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(3, 4)) } test("repeatedly restart") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + val input = ContinuousMemoryStream[Int] + val df = input.toDF() - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + testStream(df)( + StartStream(), + AddData(input, 0, 1), + CheckAnswer(0, 1), StopStream, - StartStream(longContinuousTrigger), + StartStream(), StopStream, - StartStream(longContinuousTrigger), + StartStream(), StopStream, - StartStream(longContinuousTrigger), - AwaitEpoch(2), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), + StartStream(), + StopStream, + AddData(input, 2, 3), + StartStream(), + CheckAnswer(0, 1, 2, 3), StopStream) } test("task failure kills the query") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + val input = ContinuousMemoryStream[Int] + val df = input.toDF() // Get an arbitrary task from this query to kill. It doesn't matter which one. var taskId: Long = -1 @@ -227,9 +185,9 @@ class ContinuousSuite extends ContinuousSuiteBase { } spark.sparkContext.addSparkListener(listener) try { - testStream(df, useV2Sink = true)( + testStream(df)( StartStream(Trigger.Continuous(100)), - Execute(waitForRateSourceTriggers(_, 2)), + AddData(input, 0, 1, 2, 3), Execute { _ => // Wait until a task is started, then kill its first attempt. eventually(timeout(streamingTimeout)) { @@ -252,6 +210,7 @@ class ContinuousSuite extends ContinuousSuiteBase { .option("rowsPerSecond", "2") .load() .select('value) + val query = df.writeStream .format("memory") .queryName("noharness") @@ -338,3 +297,49 @@ class ContinuousStressSuite extends ContinuousSuiteBase { CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) } } + +class ContinuousMetaSuite extends ContinuousSuiteBase { + import testImplicits._ + + // We need to specify spark.sql.streaming.minBatchesToRetain to do the following test. + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[10]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true") + .set("spark.sql.streaming.minBatchesToRetain", "2"))) + + test("SPARK-24351: check offsetLog/commitLog retained in the checkpoint directory") { + withTempDir { checkpointDir => + val input = ContinuousMemoryStream[Int] + val df = input.toDF().mapPartitions(iter => { + // Sleep the task thread for 300 ms to make sure epoch processing time 3 times + // longer than epoch creating interval. So the gap between last committed + // epoch and currentBatchId grows over time. + Thread.sleep(300) + iter.map(row => row.getInt(0) * 2) + }) + + testStream(df)( + StartStream(trigger = Trigger.Continuous(100), + checkpointLocation = checkpointDir.getAbsolutePath), + AddData(input, 1), + CheckAnswer(2), + // Make sure epoch 2 has been committed before the following validation. + AwaitEpoch(2), + StopStream, + AssertOnQuery(q => { + q.commitLog.getLatest() match { + case Some((latestEpochId, _)) => + val commitLogValidateResult = q.commitLog.get(latestEpochId - 1).isDefined && + q.commitLog.get(latestEpochId - 2).isEmpty + val offsetLogValidateResult = q.offsetLog.get(latestEpochId - 1).isDefined && + q.offsetLog.get(latestEpochId - 2).isEmpty + commitLogValidateResult && offsetLogValidateResult + case None => false + } + }) + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index 99e30561f81d5..82836dced9df7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -120,7 +120,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2)) } - ignore("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { + test("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { setWriterPartitions(2) setReaderPartitions(2) @@ -141,7 +141,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2)) } - ignore("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { + test("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { setWriterPartitions(1) setReaderPartitions(1) @@ -162,7 +162,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2, 3, 4)) } - ignore("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { + test("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { setWriterPartitions(1) setReaderPartitions(1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala new file mode 100644 index 0000000000000..f84f3d49707bf --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.UUID + +import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +class ContinuousShuffleSuite extends StreamTest { + // In this unit test, we emulate that we're in the task thread where + // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context + // thread local to be set. + var ctx: TaskContextImpl = _ + + override def beforeEach(): Unit = { + super.beforeEach() + ctx = TaskContext.empty() + TaskContext.setTaskContext(ctx) + } + + override def afterEach(): Unit = { + ctx.markTaskCompleted(None) + TaskContext.unset() + ctx = null + super.afterEach() + } + + private implicit def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + private def unsafeRow(value: String) = { + UnsafeProjection.create(Array(StringType : DataType))( + new GenericInternalRow(Array(UTF8String.fromString(value): Any))) + } + + private def send(endpoint: RpcEndpointRef, messages: RPCContinuousShuffleMessage*) = { + messages.foreach(endpoint.askSync[Unit](_)) + } + + private def readRDDEndpoint(rdd: ContinuousShuffleReadRDD) = { + rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + } + + private def readEpoch(rdd: ContinuousShuffleReadRDD) = { + rdd.compute(rdd.partitions(0), ctx).toSeq.map(_.getInt(0)) + } + + test("reader - one epoch") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverRow(0, unsafeRow(222)), + ReceiverRow(0, unsafeRow(333)), + ReceiverEpochMarker(0) + ) + + val iter = rdd.compute(rdd.partitions(0), ctx) + assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) + } + + test("reader - multiple epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(222)), + ReceiverRow(0, unsafeRow(333)), + ReceiverEpochMarker(0) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) + } + + test("reader - empty epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + + send( + endpoint, + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0) + ) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + } + + test("reader - multiple partitions") { + val rdd = new ContinuousShuffleReadRDD( + sparkContext, + numPartitions = 5, + endpointNames = Seq.fill(5)(s"endpt-${UUID.randomUUID()}")) + // Send all data before processing to ensure there's no crossover. + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + // Send index for identification. + send( + part.endpoint, + ReceiverRow(0, unsafeRow(part.index)), + ReceiverEpochMarker(0) + ) + } + + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + val iter = rdd.compute(part, ctx) + assert(iter.next().getInt(0) == part.index) + assert(!iter.hasNext) + } + } + + test("reader - blocks waiting for new rows") { + val rdd = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue) + val epoch = rdd.compute(rdd.partitions(0), ctx) + + val readRowThread = new Thread { + override def run(): Unit = { + try { + epoch.next().getInt(0) + } catch { + case _: InterruptedException => // do nothing - expected at test ending + } + } + } + + try { + readRowThread.start() + eventually(timeout(streamingTimeout)) { + assert(readRowThread.getState == Thread.State.TIMED_WAITING) + } + } finally { + readRowThread.interrupt() + readRowThread.join() + } + } + + test("reader - multiple writers") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(1), + ReceiverEpochMarker(2) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + } + + test("reader - epoch only ends when all writers send markers") { + val rdd = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(2) + ) + + val epoch = rdd.compute(rdd.partitions(0), ctx) + val rows = (0 until 3).map(_ => epoch.next()).toSet + assert(rows.map(_.getUTF8String(0).toString) == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + + // After checking the right rows, block until we get an epoch marker indicating there's no next. + // (Also fail the assertion if for some reason we get a row.) + + val readEpochMarkerThread = new Thread { + override def run(): Unit = { + assert(!epoch.hasNext) + } + } + + readEpochMarkerThread.start() + eventually(timeout(streamingTimeout)) { + assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) + } + + // Send the last epoch marker - now the epoch should finish. + send(endpoint, ReceiverEpochMarker(1)) + eventually(timeout(streamingTimeout)) { + !readEpochMarkerThread.isAlive + } + + // Join to pick up assertion failures. + readEpochMarkerThread.join(streamingTimeout.toMillis) + } + + test("reader - writer epochs non aligned") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + // We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should + // collate them as though the markers were aligned in the first place. + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow("writer0-row1")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + + ReceiverEpochMarker(1), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverEpochMarker(1), + ReceiverRow(1, unsafeRow("writer1-row1")), + ReceiverEpochMarker(1), + + ReceiverEpochMarker(2), + ReceiverEpochMarker(2), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(2) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(firstEpoch == Set("writer0-row0")) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(secondEpoch == Set("writer0-row1", "writer1-row0")) + + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(thirdEpoch == Set("writer1-row1", "writer2-row0")) + } + + test("one epoch") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator(1, 2, 3)) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + } + + test("multiple epochs") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator(1, 2, 3)) + writer.write(Iterator(4, 5, 6)) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + assert(readEpoch(reader) == Seq(4, 5, 6)) + } + + test("empty epochs") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator()) + writer.write(Iterator(1, 2)) + writer.write(Iterator()) + writer.write(Iterator()) + writer.write(Iterator(3, 4)) + writer.write(Iterator()) + + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(1, 2)) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(3, 4)) + assert(readEpoch(reader) == Seq()) + } + + test("blocks waiting for writer") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readRowThread = new Thread { + override def run(): Unit = { + assert(readerEpoch.toSeq.map(_.getInt(0)) == Seq(1)) + } + } + readRowThread.start() + + eventually(timeout(streamingTimeout)) { + assert(readRowThread.getState == Thread.State.TIMED_WAITING) + } + + // Once we write the epoch the thread should stop waiting and succeed. + writer.write(Iterator(1)) + readRowThread.join(streamingTimeout.toMillis) + } + + test("multiple writer partitions") { + val numWriterPartitions = 3 + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writers = (0 until 3).map { idx => + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } + + writers(0).write(Iterator(1, 4, 7)) + writers(1).write(Iterator(2, 5)) + writers(2).write(Iterator(3, 6)) + + writers(0).write(Iterator(4, 7, 10)) + writers(1).write(Iterator(5, 8)) + writers(2).write(Iterator(6, 9)) + + // Since there are multiple asynchronous writers, the original row sequencing is not guaranteed. + // The epochs should be deterministically preserved, however. + assert(readEpoch(reader).toSet == Seq(1, 2, 3, 4, 5, 6, 7).toSet) + assert(readEpoch(reader).toSet == Seq(4, 5, 6, 7, 8, 9, 10).toSet) + } + + test("reader epoch only ends when all writer partitions write it") { + val numWriterPartitions = 3 + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writers = (0 until 3).map { idx => + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } + + writers(1).write(Iterator()) + writers(2).write(Iterator()) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readEpochMarkerThread = new Thread { + override def run(): Unit = { + assert(!readerEpoch.hasNext) + } + } + + readEpochMarkerThread.start() + eventually(timeout(streamingTimeout)) { + assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) + } + + writers(0).write(Iterator()) + readEpochMarkerThread.join(streamingTimeout.toMillis) + } + + test("receiver stopped with row last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } + + test("receiver stopped with marker last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index af4618bed5456..c1a28b9bc75ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.InputPartition import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} @@ -44,7 +44,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader { def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) def setStartOffset(start: Optional[Offset]): Unit = {} - def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = { + def planInputPartitions(): java.util.ArrayList[InputPartition[Row]] = { throw new IllegalStateException("fake source - cannot actually read") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 14b1feb2adc20..b65058fffd339 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -276,7 +276,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be assert(LastOptions.parameters("doubleOpt") == "6.7") } - test("check jdbc() does not support partitioning or bucketing") { + test("check jdbc() does not support partitioning, bucketBy or sortBy") { val df = spark.read.text(Utils.createTempDir(namePrefix = "text").getCanonicalPath) var w = df.write.partitionBy("value") @@ -287,7 +287,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be w = df.write.bucketBy(2, "value") e = intercept[AnalysisException](w.jdbc(null, null, null)) - Seq("jdbc", "bucketing").foreach { s => + Seq("jdbc", "does not support bucketBy right now").foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + + w = df.write.sortBy("value") + e = intercept[AnalysisException](w.jdbc(null, null, null)) + Seq("sortBy must be used together with bucketBy").foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + + w = df.write.bucketBy(2, "value").sortBy("value") + e = intercept[AnalysisException](w.jdbc(null, null, null)) + Seq("jdbc", "does not support bucketBy and sortBy right now").foreach { s => assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java index c5ade65283045..10000f12ab329 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java @@ -18,6 +18,8 @@ package org.apache.hive.service.auth; import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.net.UnknownHostException; import java.util.ArrayList; @@ -26,6 +28,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import javax.net.ssl.SSLServerSocket; import javax.security.auth.login.LoginException; @@ -92,7 +95,30 @@ public String getAuthName() { public static final String HS2_PROXY_USER = "hive.server2.proxy.user"; public static final String HS2_CLIENT_TOKEN = "hiveserver2ClientToken"; - public HiveAuthFactory(HiveConf conf) throws TTransportException { + private static Field keytabFile = null; + private static Method getKeytab = null; + static { + Class clz = UserGroupInformation.class; + try { + keytabFile = clz.getDeclaredField("keytabFile"); + keytabFile.setAccessible(true); + } catch (NoSuchFieldException nfe) { + LOG.debug("Cannot find private field \"keytabFile\" in class: " + + UserGroupInformation.class.getCanonicalName(), nfe); + keytabFile = null; + } + + try { + getKeytab = clz.getDeclaredMethod("getKeytab"); + getKeytab.setAccessible(true); + } catch(NoSuchMethodException nme) { + LOG.debug("Cannot find private method \"getKeytab\" in class:" + + UserGroupInformation.class.getCanonicalName(), nme); + getKeytab = null; + } + } + + public HiveAuthFactory(HiveConf conf) throws TTransportException, IOException { this.conf = conf; transportMode = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_TRANSPORT_MODE); authTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_AUTHENTICATION); @@ -107,9 +133,16 @@ public HiveAuthFactory(HiveConf conf) throws TTransportException { authTypeStr = AuthTypes.NONE.getAuthName(); } if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) { - saslServer = ShimLoader.getHadoopThriftAuthBridge() - .createServer(conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB), - conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL)); + String principal = conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL); + String keytab = conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB); + if (needUgiLogin(UserGroupInformation.getCurrentUser(), + SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), keytab)) { + saslServer = ShimLoader.getHadoopThriftAuthBridge().createServer(principal, keytab); + } else { + // Using the default constructor to avoid unnecessary UGI login. + saslServer = new HadoopThriftAuthBridge.Server(); + } + // start delegation token manager try { // rawStore is only necessary for DBTokenStore @@ -362,4 +395,25 @@ public static void verifyProxyAccess(String realUser, String proxyUser, String i } } + public static boolean needUgiLogin(UserGroupInformation ugi, String principal, String keytab) { + return null == ugi || !ugi.hasKerberosCredentials() || !ugi.getUserName().equals(principal) || + !Objects.equals(keytab, getKeytabFromUgi()); + } + + private static String getKeytabFromUgi() { + synchronized (UserGroupInformation.class) { + try { + if (keytabFile != null) { + return (String) keytabFile.get(null); + } else if (getKeytab != null) { + return (String) getKeytab.invoke(UserGroupInformation.getCurrentUser()); + } else { + return null; + } + } catch (Exception e) { + LOG.debug("Fail to get keytabFile path via reflection", e); + return null; + } + } + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 084f8200102ba..d9fd3ebd3c65d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -35,7 +35,7 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.log4j.{Level, Logger} +import org.apache.log4j.Level import org.apache.thrift.transport.TSocket import org.apache.spark.SparkConf @@ -300,10 +300,6 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val console = new SessionState.LogHelper(LOG) - if (sessionState.getIsSilent) { - Logger.getRootLogger.setLevel(Level.WARN) - } - private val isRemoteMode = { SparkSQLCLIDriver.isRemoteMode(sessionState) } @@ -315,6 +311,9 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { // because the Hive unit tests do not go through the main() code path. if (!isRemoteMode) { SparkSQLEnv.init() + if (sessionState.getIsSilent) { + SparkSQLEnv.sparkContext.setLogLevel(Level.WARN.toString) + } } else { // Hive 1.2 + not supported in CLI throw new RuntimeException("Remote operations not supported") diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index ad1f5eb9ca3a7..1335e16e35882 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -27,7 +27,7 @@ import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.shims.Utils -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{SecurityUtil, UserGroupInformation} import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory @@ -52,8 +52,22 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLC if (UserGroupInformation.isSecurityEnabled) { try { - HiveAuthFactory.loginFromKeytab(hiveConf) - sparkServiceUGI = Utils.getUGI() + val principal = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL) + val keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB) + if (principal.isEmpty || keyTabFile.isEmpty) { + throw new IOException( + "HiveServer2 Kerberos principal or keytab is not correctly configured") + } + + val originalUgi = UserGroupInformation.getCurrentUser + sparkServiceUGI = if (HiveAuthFactory.needUgiLogin(originalUgi, + SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), keyTabFile)) { + HiveAuthFactory.loginFromKeytab(hiveConf) + Utils.getUGI() + } else { + originalUgi + } + setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index cbd75ad12d430..8980bcf885589 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -50,7 +50,7 @@ private[hive] object SparkSQLEnv extends Logging { sqlContext = sparkSession.sqlContext val metadataHive = sparkSession - .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + .sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index f517bffccdf31..771104ceb8842 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -47,10 +47,10 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" {listener.getOnlineSessionNum} session(s) are online, running {listener.getTotalRunning} SQL statement(s) ++ - generateSessionStatsTable() ++ - generateSQLStatsTable() + generateSessionStatsTable(request) ++ + generateSQLStatsTable(request) } - UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "JDBC/ODBC Server", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -67,7 +67,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } /** Generate stats of batch statements of the thrift server program */ - private def generateSQLStatsTable(): Seq[Node] = { + private def generateSQLStatsTable(request: HttpServletRequest): Seq[Node] = { val numStatement = listener.getExecutionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", @@ -76,7 +76,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - + [{id}] } @@ -138,7 +139,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } /** Generate stats of batch sessions of the thrift server program */ - private def generateSessionStatsTable(): Seq[Node] = { + private def generateSessionStatsTable(request: HttpServletRequest): Seq[Node] = { val sessionList = listener.getSessionList val numBatches = sessionList.size val table = if (numBatches > 0) { @@ -146,8 +147,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/%s/session?id=%s" - .format(UIUtils.prependBaseUri(parent.basePath), parent.prefix, session.sessionId) + val sessionLink = "%s/%s/session/?id=%s".format( + UIUtils.prependBaseUri(request, parent.basePath), parent.prefix, session.sessionId) {session.userName} {session.ip} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 5cd2fdf6437c2..163eb43aabc72 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -56,9 +56,9 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) Session created at {formatDate(sessionStat.startTimestamp)}, Total run {sessionStat.totalExecution} SQL ++ - generateSQLStatsTable(sessionStat.sessionId) + generateSQLStatsTable(request, sessionStat.sessionId) } - UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "JDBC/ODBC Session", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -75,7 +75,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) } /** Generate stats of batch statements of the thrift server program */ - private def generateSQLStatsTable(sessionID: String): Seq[Node] = { + private def generateSQLStatsTable(request: HttpServletRequest, sessionID: String): Seq[Node] = { val executionList = listener.getExecutionList .filter(_.sessionId == sessionID) val numStatement = executionList.size @@ -86,7 +86,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - + [{id}] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 28c340a176d91..011a3ba553cb2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -158,13 +158,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Databases // -------------------------------------------------------------------------- - override protected def doCreateDatabase( + override def createDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = withClient { client.createDatabase(dbDefinition, ignoreIfExists) } - override protected def doDropDatabase( + override def dropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = withClient { @@ -177,7 +177,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * * Note: As of now, this only supports altering database properties! */ - override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { val existingDb = getDatabase(dbDefinition.name) if (existingDb.properties == dbDefinition.properties) { logWarning(s"Request to alter database ${dbDefinition.name} is a no-op because " + @@ -211,7 +211,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Tables // -------------------------------------------------------------------------- - override protected def doCreateTable( + override def createTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) @@ -480,7 +480,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override protected def doDropTable( + override def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -489,7 +489,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.dropTable(db, table, ignoreIfNotExists, purge) } - override protected def doRenameTable( + override def renameTable( db: String, oldName: String, newName: String): Unit = withClient { @@ -540,7 +540,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * Note: As of now, this doesn't support altering table schema, partition column names and bucket * specification. We will ignore them even if users do specify different values for these fields. */ - override def doAlterTable(tableDefinition: CatalogTable): Unit = withClient { + override def alterTable(tableDefinition: CatalogTable): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -624,7 +624,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * data schema should not have conflict column names with the existing partition columns, and * should still contain all the existing data columns. */ - override def doAlterTableDataSchema( + override def alterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = withClient { @@ -656,7 +656,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ - override def doAlterTableStats( + override def alterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = withClient { @@ -1208,7 +1208,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Functions // -------------------------------------------------------------------------- - override protected def doCreateFunction( + override def createFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) @@ -1221,12 +1221,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override protected def doDropFunction(db: String, name: String): Unit = withClient { + override def dropFunction(db: String, name: String): Unit = withClient { requireFunctionExists(db, name) client.dropFunction(db, name) } - override protected def doAlterFunction( + override def alterFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) val functionName = funcDefinition.identifier.funcName.toLowerCase(Locale.ROOT) @@ -1235,7 +1235,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.alterFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override protected def doRenameFunction( + override def renameFunction( db: String, oldName: String, newName: String): Unit = withClient { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index e5aff3b99d0b9..de41bb418181d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, Gener import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, ExternalCatalog, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{DecimalType, DoubleType} private[sql] class HiveSessionCatalog( - externalCatalogBuilder: () => HiveExternalCatalog, + externalCatalogBuilder: () => ExternalCatalog, globalTempViewManagerBuilder: () => GlobalTempViewManager, val metastoreCatalog: HiveMetastoreCatalog, functionRegistry: FunctionRegistry, @@ -175,6 +175,10 @@ private[sql] class HiveSessionCatalog( super.functionExists(name) || hiveFunctions.contains(name.funcName) } + override def isPersistentFunction(name: FunctionIdentifier): Boolean = { + super.isPersistentFunction(name) || hiveFunctions.contains(name.funcName) + } + /** List of functions we pass over to Hive. Note that over time this list should go to 0. */ // We have a list of Hive built-in functions that we do not support. So, we will check // Hive's function registry and lazily load needed functions into our own function registry. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 40b9bb51ca9a0..2882672f327c4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.Analyzer +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlanner @@ -35,14 +36,14 @@ import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLo class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None) extends BaseSessionStateBuilder(session, parentState) { - private def externalCatalog: HiveExternalCatalog = - session.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private def externalCatalog: ExternalCatalogWithListener = session.sharedState.externalCatalog /** * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - new HiveSessionResourceLoader(session, () => externalCatalog.client) + new HiveSessionResourceLoader( + session, () => externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client) } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 8df05cbb20361..a0c197b06ddab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -186,15 +186,28 @@ case class RelationConversions( serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) } + // Return true for Apache ORC and Hive ORC-related configuration names. + // Note that Spark doesn't support configurations like `hive.merge.orcfile.stripe.level`. + private def isOrcProperty(key: String) = + key.startsWith("orc.") || key.contains(".orc.") + + private def isParquetProperty(key: String) = + key.startsWith("parquet.") || key.contains(".parquet.") + private def convert(relation: HiveTableRelation): LogicalRelation = { val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + + // Consider table and storage properties. For properties existing in both sides, storage + // properties will supersede table properties. if (serde.contains("parquet")) { - val options = relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> + val options = relation.tableMeta.properties.filterKeys(isParquetProperty) ++ + relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) sessionCatalog.metastoreCatalog .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") } else { - val options = relation.tableMeta.storage.properties + val options = relation.tableMeta.properties.filterKeys(isOrcProperty) ++ + relation.tableMeta.storage.properties if (conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") { sessionCatalog.metastoreCatalog.convertToLogicalRelation( relation, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 10c9603745379..cd321d41f43e8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -62,7 +62,7 @@ private[spark] object HiveUtils extends Logging { val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. Available options are " + - s"0.12.0 through 2.3.2.") + s"0.12.0 through 2.3.3.") .stringConf .createWithDefault(builtinHiveVersion) @@ -105,11 +105,10 @@ private[spark] object HiveUtils extends Logging { .createWithDefault(false) val CONVERT_METASTORE_ORC = buildConf("spark.sql.hive.convertMetastoreOrc") - .internal() .doc("When set to true, the built-in ORC reader and writer are used to process " + "ORC tables created by using the HiveQL syntax, instead of Hive serde.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val HIVE_METASTORE_SHARED_PREFIXES = buildConf("spark.sql.hive.metastore.sharedPrefixes") .doc("A comma separated list of class prefixes that should be loaded using the classloader " + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index da9fe2d3088b4..1df46d7431a21 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -995,6 +995,8 @@ private[hive] object HiveClientImpl { tpart.setTableName(ht.getTableName) tpart.setValues(partValues.asJava) tpart.setSd(storageDesc) + tpart.setCreateTime((p.createTime / 1000).toInt) + tpart.setLastAccessTime((p.lastAccessTime / 1000).toInt) tpart.setParameters(mutable.Map(p.parameters.toSeq: _*).asJava) new HivePartition(ht, tpart) } @@ -1019,6 +1021,8 @@ private[hive] object HiveClientImpl { compressed = apiPartition.getSd.isCompressed, properties = Option(apiPartition.getSd.getSerdeInfo.getParameters) .map(_.asScala.toMap).orNull), + createTime = apiPartition.getCreateTime.toLong * 1000, + lastAccessTime = apiPartition.getLastAccessTime.toLong * 1000, parameters = properties, stats = readHiveStats(properties)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 948ba542b5733..933384ed43e98 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -24,7 +24,6 @@ import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, S import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ -import scala.util.Try import scala.util.control.NonFatal import org.apache.hadoop.fs.Path @@ -46,7 +45,7 @@ import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegralType, StringType} +import org.apache.spark.sql.types.{AtomicType, IntegralType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -343,7 +342,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { } override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { - conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000 + conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000L } override def loadPartition( @@ -657,17 +656,32 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled + object ExtractAttribute { + def unapply(expr: Expression): Option[Attribute] = { + expr match { + case attr: Attribute => Some(attr) + case Cast(child @ AtomicType(), dt: AtomicType, _) + if Cast.canSafeCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child) + case _ => None + } + } + } + def convert(expr: Expression): Option[String] = expr match { - case In(NonVarcharAttribute(name), ExtractableLiterals(values)) if useAdvanced => + case In(ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiterals(values)) + if useAdvanced => Some(convertInToOr(name, values)) - case InSet(NonVarcharAttribute(name), ExtractableValues(values)) if useAdvanced => + case InSet(ExtractAttribute(NonVarcharAttribute(name)), ExtractableValues(values)) + if useAdvanced => Some(convertInToOr(name, values)) - case op @ SpecialBinaryComparison(NonVarcharAttribute(name), ExtractableLiteral(value)) => + case op @ SpecialBinaryComparison( + ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiteral(value)) => Some(s"$name ${op.symbol} $value") - case op @ SpecialBinaryComparison(ExtractableLiteral(value), NonVarcharAttribute(name)) => + case op @ SpecialBinaryComparison( + ExtractableLiteral(value), ExtractAttribute(NonVarcharAttribute(name))) => Some(s"$value ${op.symbol} $name") case And(expr1, expr2) if useAdvanced => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index c2690ec32b9e7..6a90c44a2633d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -98,7 +98,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 case "2.2" | "2.2.0" => hive.v2_2 - case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" => hive.v2_3 + case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" | "2.3.3" => hive.v2_3 } private def downloadVersion( @@ -182,6 +182,7 @@ private[hive] class IsolatedClientLoader( name.startsWith("org.slf4j") || name.startsWith("org.apache.log4j") || // log4j1.x name.startsWith("org.apache.logging.log4j") || // log4j2 + name.startsWith("org.apache.derby.") || name.startsWith("org.apache.spark.") || (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 681ee9200f02b..25e9886fa6576 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -75,7 +75,7 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - case object v2_3 extends HiveVersion("2.3.2", + case object v2_3 extends HiveVersion("2.3.3", exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 6a7b25b36d9a5..e0f7375387d24 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -122,7 +122,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { allSupportedHiveVersions) val externalCatalog = sparkSession.sharedState.externalCatalog - val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version + val hiveVersion = externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.version val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging") val scratchDir = hadoopConf.get("hive.exec.scratchdir", "/tmp/hive") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 237ed9bc05988..20090696ec3fc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.OrcOptions import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} import org.apache.spark.sql.sources.{Filter, _} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration /** @@ -72,6 +72,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { + val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) val configuration = job.getConfiguration @@ -121,6 +122,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + if (sparkSession.sessionState.conf.orcFilterPushDown) { // Sets pushed predicates OrcFilters.createFilter(requiredSchema, filters.toArray).foreach { f => @@ -174,6 +176,23 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => isReadPath + + case _ => false + } } private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 965aea2b61456..ee3f99ab7e9bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, ExternalCatalogWithListener} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.CacheTableCommand @@ -83,11 +84,11 @@ private[hive] class TestHiveSharedState( hiveClient: Option[HiveClient] = None) extends SharedState(sc) { - override lazy val externalCatalog: TestHiveExternalCatalog = { - new TestHiveExternalCatalog( + override lazy val externalCatalog: ExternalCatalogWithListener = { + new ExternalCatalogWithListener(new TestHiveExternalCatalog( sc.conf, sc.hadoopConfiguration, - hiveClient) + hiveClient)) } } @@ -208,7 +209,9 @@ private[hive] class TestHiveSparkSession( new TestHiveSessionStateBuilder(this, parentSessionState).build() } - lazy val metadataHive: HiveClient = sharedState.externalCatalog.client.newSession() + lazy val metadataHive: HiveClient = { + sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.newSession() + } override def newSession(): TestHiveSparkSession = { new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 48ab4eb9a6178..569f00c053e5f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -38,7 +38,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto val plan = table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => - relation.cachedColumnBuffers.id + relation.cacheBuilder.cachedColumnBuffers.id case _ => fail(s"Table $tableName is not cached\n" + plan) }.head diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index d10a6f25c64fc..4550d350f6db2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -268,12 +268,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo compressionCodecs = compressCodecs, tableCompressionCodecs = compressCodecs) { case (tableCodec, sessionCodec, realCodec, tableSize) => - // For non-partitioned table and when convertMetastore is true, Expect session-level - // take effect, and in other cases expect table-level take effect - // TODO: It should always be table-level taking effect when the bug(SPARK-22926) - // is fixed - val expectCodec = - if (convertMetastore && !isPartitioned) sessionCodec else tableCodec.get + val expectCodec = tableCodec.get assert(expectCodec == realCodec) assert(checkTableSize( format, expectCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 0a522b6a11c80..1de258f060943 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -113,4 +113,10 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { catalog.createDatabase(newDb("dbWithNullDesc").copy(description = null), ignoreIfExists = false) assert(catalog.getDatabase("dbWithNullDesc").description == "") } + + test("SPARK-23831: Add org.apache.derby to IsolatedClientLoader") { + val client1 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) + val client2 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) + assert(!client1.equals(client2)) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 6ca58e68d31eb..514921875f1f9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -67,7 +67,21 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { logInfo(s"Downloading Spark $version from $url") try { getFileFromUrl(url, path, filename) - return + val downloaded = new File(sparkTestingDir, filename).getCanonicalPath + val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath + + Seq("mkdir", targetDir).! + val exitCode = Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! + Seq("rm", downloaded).! + + // For a corrupted file, `tar` returns non-zero values. However, we also need to check + // the extracted file because `tar` returns 0 for empty file. + val sparkSubmit = new File(sparkTestingDir, s"spark-$version/bin/spark-submit") + if (exitCode == 0 && sparkSubmit.exists()) { + return + } else { + Seq("rm", "-rf", targetDir).! + } } catch { case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) } @@ -75,20 +89,6 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { fail(s"Unable to download Spark $version") } - - private def downloadSpark(version: String): Unit = { - tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) - - val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath - val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath - - Seq("mkdir", targetDir).! - - Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! - - Seq("rm", downloaded).! - } - private def genDataDir(name: String): String = { new File(tmpDataDir, name).getCanonicalPath } @@ -161,7 +161,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => val sparkHome = new File(sparkTestingDir, s"spark-$version") if (!sparkHome.exists()) { - downloadSpark(version) + tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) } val args = Seq( @@ -195,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1", "2.3.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.1") protected var spark: SparkSession = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala index 285f35b0b0eac..fd5f47e428239 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala @@ -26,7 +26,7 @@ class HiveExternalSessionCatalogSuite extends SessionCatalogSuite with TestHiveS private val externalCatalog = { val catalog = spark.sharedState.externalCatalog - catalog.asInstanceOf[HiveExternalCatalog].client.reset() + catalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.reset() catalog } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index f2d27671094d7..51a48a20daaa2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -50,7 +50,8 @@ class HiveSchemaInferenceSuite FileStatusCache.resetForTesting() } - private val externalCatalog = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private val externalCatalog = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] private val client = externalCatalog.client // Return a copy of the given schema with all field names converted to lower case. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index ecc09cdcdbeaf..a3579862c9e59 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -44,7 +44,8 @@ class HiveSessionStateSuite extends SessionStateSuite with TestHiveSingleton { val conf = sparkSession.sparkContext.hadoopConfiguration val oldValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) sparkSession.cloneSession() - sparkSession.sharedState.externalCatalog.client.newSession() + sparkSession.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.newSession() val newValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) assert(oldValue == newValue, "cloneSession and then newSession should not affect the Derby directory") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 079fe45860544..aa5b531992613 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -354,7 +354,7 @@ object SetMetastoreURLTest extends Logging { // HiveExternalCatalog is used when Hive support is enabled. val actualMetastoreURL = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client .getConf("javax.jdo.option.ConnectionURL", "this_is_a_wrong_URL") logInfo(s"javax.jdo.option.ConnectionURL is $actualMetastoreURL") @@ -780,7 +780,8 @@ object SPARK_18360 { val defaultDbLocation = spark.catalog.getDatabase("default").locationUri assert(new Path(defaultDbLocation) == new Path(spark.sharedState.warehousePath)) - val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val hiveClient = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client try { val tableMeta = CatalogTable( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala index 95f192f0e40e2..1e525c46a9cfb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -32,13 +33,22 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto import spark.implicits._ - before { + override def beforeAll(): Unit = { + super.beforeAll() sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)") (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)")) } + override protected def afterAll(): Unit = { + try { + sql("DROP TABLE IF EXISTS metadata_only") + } finally { + super.afterAll() + } + } + test("SPARK-23877: validate metadata-only query pushes filters to metastore") { - withTable("metadata_only") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount // verify the number of matching partitions @@ -50,7 +60,7 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto } test("SPARK-23877: filter on projected expression") { - withTable("metadata_only") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount // verify the matching partitions diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala index fad81c7e9474e..473bbced41b31 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala @@ -289,7 +289,8 @@ class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSing } private def createRawHiveTable(ddl: String): Unit = { - hiveContext.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client.runSqlHive(ddl) + hiveContext.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.runSqlHive(ddl) } private def checkCreateTable(table: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index f991352b207d4..fa9f753795f65 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -22,13 +22,13 @@ import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, In, InSet} -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType} // TODO: Refactor this to `HivePartitionFilteringSuite` class HiveClientSuite(version: String) extends HiveVersionSuite(version) with BeforeAndAfterAll { - import CatalystSqlParser._ private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname @@ -46,8 +46,7 @@ class HiveClientSuite(version: String) val hadoopConf = new Configuration() hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql) val client = buildClient(hadoopConf) - client - .runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)") + client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)") val partitions = for { @@ -66,6 +65,15 @@ class HiveClientSuite(version: String) client } + private def attr(name: String): Attribute = { + client.getTable("default", "test").partitionSchema.fields + .find(field => field.name.equals(name)) match { + case Some(field) => AttributeReference(field.name, field.dataType)() + case None => + fail(s"Illegal name of partition attribute: $name") + } + } + override def beforeAll() { super.beforeAll() client = init(true) @@ -74,7 +82,7 @@ class HiveClientSuite(version: String) test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") { val client = init(false) val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), - Seq(parseExpression("ds=20170101"))) + Seq(attr("ds") === 20170101)) assert(filteredPartitions.size == testPartitionCount) } @@ -82,7 +90,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds<=>20170101") { // Should return all partitions where <=> is not supported testMetastorePartitionFiltering( - "ds<=>20170101", + attr("ds") <=> 20170101, 20170101 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -90,7 +98,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101") { testMetastorePartitionFiltering( - "ds=20170101", + attr("ds") === 20170101, 20170101 to 20170101, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -100,7 +108,7 @@ class HiveClientSuite(version: String) // Should return all partitions where h=0 because getPartitionsByFilter does not support // comparisons to non-literal values testMetastorePartitionFiltering( - "ds=(20170101 + 1) and h=0", + attr("ds") === (Literal(20170101) + 1) && attr("h") === 0, 20170101 to 20170103, 0 to 0, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -108,15 +116,31 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk='aa'") { testMetastorePartitionFiltering( - "chunk='aa'", + attr("chunk") === "aa", 20170101 to 20170103, 0 to 23, "aa" :: Nil) } + test("getPartitionsByFilter: cast(chunk as int)=1 (not a valid partition predicate)") { + testMetastorePartitionFiltering( + attr("chunk").cast(IntegerType) === 1, + 20170101 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: cast(chunk as boolean)=true (not a valid partition predicate)") { + testMetastorePartitionFiltering( + attr("chunk").cast(BooleanType) === true, + 20170101 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + test("getPartitionsByFilter: 20170101=ds") { testMetastorePartitionFiltering( - "20170101=ds", + Literal(20170101) === attr("ds"), 20170101 to 20170101, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -124,7 +148,15 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101 and h=10") { testMetastorePartitionFiltering( - "ds=20170101 and h=10", + attr("ds") === 20170101 && attr("h") === 10, + 20170101 to 20170101, + 10 to 10, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: cast(ds as long)=20170101L and h=10") { + testMetastorePartitionFiltering( + attr("ds").cast(LongType) === 20170101L && attr("h") === 10, 20170101 to 20170101, 10 to 10, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -132,7 +164,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101 or ds=20170102") { testMetastorePartitionFiltering( - "ds=20170101 or ds=20170102", + attr("ds") === 20170101 || attr("ds") === 20170102, 20170101 to 20170102, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -140,7 +172,15 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") { testMetastorePartitionFiltering( - "ds in (20170102, 20170103)", + attr("ds").in(20170102, 20170103), + 20170102 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using IN expression)") { + testMetastorePartitionFiltering( + attr("ds").cast(LongType).in(20170102L, 20170103L), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -148,7 +188,19 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") { testMetastorePartitionFiltering( - "ds in (20170102, 20170103)", + attr("ds").in(20170102, 20170103), + 20170102 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil, { + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v, list.map(_.eval(EmptyRow)).toSet) + }) + } + + test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using INSET expression)") + { + testMetastorePartitionFiltering( + attr("ds").cast(LongType).in(20170102L, 20170103L), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { @@ -159,7 +211,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") { testMetastorePartitionFiltering( - "chunk in ('ab', 'ba')", + attr("chunk").in("ab", "ba"), 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil) @@ -167,7 +219,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") { testMetastorePartitionFiltering( - "chunk in ('ab', 'ba')", + attr("chunk").in("ab", "ba"), 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil, { @@ -179,26 +231,24 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") { val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb")) - testMetastorePartitionFiltering( - "(ds=20170101 and h>=8) or (ds=20170102 and h<8)", - day1 :: day2 :: Nil) + testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) || + (attr("ds") === 20170102 && attr("h") < 8), day1 :: day2 :: Nil) } test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") { val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) // Day 2 should include all hours because we can't build a filter for h<(7+1) val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb")) - testMetastorePartitionFiltering( - "(ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))", - day1 :: day2 :: Nil) + testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) || + (attr("ds") === 20170102 && attr("h") < (Literal(7) + 1)), day1 :: day2 :: Nil) } test("getPartitionsByFilter: " + "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") { val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba")) val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba")) - testMetastorePartitionFiltering( - "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))", + testMetastorePartitionFiltering(attr("chunk").in("ab", "ba") && + ((attr("ds") === 20170101 && attr("h") >= 8) || (attr("ds") === 20170102 && attr("h") < 8)), day1 :: day2 :: Nil) } @@ -207,41 +257,41 @@ class HiveClientSuite(version: String) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedDs: Seq[Int], expectedH: Seq[Int], expectedChunks: Seq[String]): Unit = { testMetastorePartitionFiltering( - filterString, + filterExpr, (expectedDs, expectedH, expectedChunks) :: Nil, identity) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedDs: Seq[Int], expectedH: Seq[Int], expectedChunks: Seq[String], transform: Expression => Expression): Unit = { testMetastorePartitionFiltering( - filterString, + filterExpr, (expectedDs, expectedH, expectedChunks) :: Nil, - identity) + transform) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])]): Unit = { - testMetastorePartitionFiltering(filterString, expectedPartitionCubes, identity) + testMetastorePartitionFiltering(filterExpr, expectedPartitionCubes, identity) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])], transform: Expression => Expression): Unit = { val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), Seq( - transform(parseExpression(filterString)) + transform(filterExpr) )) val expectedPartitionCount = expectedPartitionCubes.map { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 6176273c88db1..dc96ec416afd8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -134,8 +134,8 @@ class VersionsSuite extends SparkFunSuite with Logging { client = buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) if (versionSpark != null) versionSpark.reset() versionSpark = TestHiveVersion(client) - assert(versionSpark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - .version.fullVersion.startsWith(version)) + assert(versionSpark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.version.fullVersion.startsWith(version)) } def table(database: String, tableName: String): CatalogTable = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index c85db78c732de..0341c3b378918 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METAS import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -1354,7 +1355,8 @@ class HiveDDLSuite val indexName = tabName + "_index" withTable(tabName) { // Spark SQL does not support creating index. Thus, we have to use Hive client. - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client sql(s"CREATE TABLE $tabName(a int)") try { @@ -1392,7 +1394,8 @@ class HiveDDLSuite val tabName = "tab1" withTable(tabName) { // Spark SQL does not support creating skewed table. Thus, we have to use Hive client. - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client client.runSqlHive( s""" |CREATE Table $tabName(col1 int, col2 int) @@ -2144,6 +2147,86 @@ class HiveDDLSuite } } + private def getReader(path: String): org.apache.orc.Reader = { + val conf = spark.sessionState.newHadoopConf() + val files = org.apache.spark.sql.execution.datasources.orc.OrcUtils.listOrcFiles(path, conf) + assert(files.length == 1) + val file = files.head + val fs = file.getFileSystem(conf) + val readerOptions = org.apache.orc.OrcFile.readerOptions(conf).filesystem(fs) + org.apache.orc.OrcFile.createReader(file, readerOptions) + } + + test("SPARK-23355 convertMetastoreOrc should not ignore table properties - STORED AS") { + Seq("native", "hive").foreach { orcImpl => + withSQLConf(ORC_IMPLEMENTATION.key -> orcImpl, CONVERT_METASTORE_ORC.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) STORED AS ORC + |TBLPROPERTIES ( + | orc.compress 'ZLIB', + | orc.compress.size '1001', + | orc.row.index.stride '2002', + | hive.exec.orc.default.block.size '3003', + | hive.exec.orc.compression.strategy 'COMPRESSION') + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains("orc")) + val properties = table.properties + assert(properties.get("orc.compress") == Some("ZLIB")) + assert(properties.get("orc.compress.size") == Some("1001")) + assert(properties.get("orc.row.index.stride") == Some("2002")) + assert(properties.get("hive.exec.orc.default.block.size") == Some("3003")) + assert(properties.get("hive.exec.orc.compression.strategy") == Some("COMPRESSION")) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + + val reader = getReader(maybeFile.head.getCanonicalPath) + assert(reader.getCompressionKind.name === "ZLIB") + assert(reader.getCompressionSize == 1001) + assert(reader.getRowIndexStride == 2002) + } + } + } + } + } + + test("SPARK-23355 convertMetastoreParquet should not ignore table properties - STORED AS") { + withSQLConf(CONVERT_METASTORE_PARQUET.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) STORED AS PARQUET + |TBLPROPERTIES ( + | parquet.compression 'GZIP' + |) + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains("parquet")) + val properties = table.properties + assert(properties.get("parquet.compression") == Some("GZIP")) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + + assertCompression(maybeFile, "parquet", "GZIP") + } + } + } + } + test("load command for non local invalid path validation") { withTable("tbl") { sql("CREATE TABLE tbl(i INT, j STRING)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 73f83d593bbfb..828c18a770c80 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2099,7 +2099,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { Seq("orc", "parquet").foreach { format => test(s"SPARK-18355 Read data from a hive table with a new column - $format") { - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client Seq("true", "false").foreach { value => withSQLConf( @@ -2156,4 +2157,35 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("SPARK-24085 scalar subquery in partitioning expression") { + Seq("orc", "parquet").foreach { format => + Seq(true, false).foreach { isConverted => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted", + HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted", + "hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(format) { + withTempPath { tempDir => + sql( + s""" + |CREATE TABLE ${format} (id_value string) + |PARTITIONED BY (id_type string) + |LOCATION '${tempDir.toURI}' + |STORED AS ${format} + """.stripMargin) + sql(s"insert into $format values ('1','a')") + sql(s"insert into $format values ('2','a')") + sql(s"insert into $format values ('3','b')") + sql(s"insert into $format values ('4','b')") + checkAnswer( + sql(s"SELECT * FROM $format WHERE id_type = (SELECT 'b')"), + Row("3", "b") :: Row("4", "b") :: Nil) + } + } + } + } + } + } + } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index d556a030e2186..fb4957ed943a7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.hive.orc import java.io.File -import org.apache.spark.sql.Row +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.TestingUDT.{IntervalData, IntervalUDT} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.orc.OrcSuite import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.HiveSerDe +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { @@ -133,4 +135,42 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { Utils.deleteRecursively(location) } } + + test("SPARK-24204 error handling for unsupported data types") { + withTempDir { dir => + val orcDir = new File(dir, "orc").getCanonicalPath + + // write path + var msg = intercept[AnalysisException] { + sql("select interval 1 days").write.mode("overwrite").orc(orcDir) + }.getMessage + assert(msg.contains("Cannot save interval data type into external storage.")) + + msg = intercept[AnalysisException] { + sql("select null").write.mode("overwrite").orc(orcDir) + }.getMessage + assert(msg.contains("ORC data source does not support null data type.")) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new IntervalData()) + sql("select testType()").write.mode("overwrite").orc(orcDir) + }.getMessage + assert(msg.contains("ORC data source does not support interval data type.")) + + // read path + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.mode("overwrite").orc(orcDir) + spark.read.schema(schema).orc(orcDir).collect() + }.getMessage + assert(msg.contains("ORC data source does not support calendarinterval data type.")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.mode("overwrite").orc(orcDir) + spark.read.schema(schema).orc(orcDir).collect() + }.getMessage + assert(msg.contains("ORC data source does not support interval data type.")) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index d3fff37c3424d..d50bf0b8fd603 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -30,7 +30,7 @@ trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { protected val spark: SparkSession = TestHive.sparkSession protected val hiveContext: TestHiveContext = TestHive protected val hiveClient: HiveClient = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client protected override def afterAll(): Unit = { try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index dce5bb7ddba66..6858bbc441721 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -124,7 +124,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { test("SPARK-8604: Parquet data source should write summary file while doing appending") { withSQLConf( - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true", + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL", SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key -> classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName) { withTempPath { dir => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index dacff69d55dd2..cf4324578ea87 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -112,10 +112,11 @@ private[streaming] class ReceivedBlockTracker( def allocateBlocksToBatch(batchTime: Time): Unit = synchronized { if (lastAllocatedBatchTime == null || batchTime > lastAllocatedBatchTime) { val streamIdToBlocks = streamIds.map { streamId => - (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) + (streamId, getReceivedBlockQueue(streamId).clone()) }.toMap val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) if (writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))) { + streamIds.foreach(getReceivedBlockQueue(_).clear()) timeToAllocatedBlocks.put(batchTime, allocatedBlocks) lastAllocatedBatchTime = batchTime } else { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 6748dd4ec48e3..884d21d0afdd3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -47,6 +47,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateJobRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, @@ -54,7 +55,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { isFirstRow: Boolean, jobIdWithData: SparkJobIdWithUIData): Seq[Node] = { if (jobIdWithData.jobData.isDefined) { - generateNormalJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, + generateNormalJobRow(request, outputOpData, outputOpDescription, formattedOutputOpDuration, numSparkJobRowsInOutputOp, isFirstRow, jobIdWithData.jobData.get) } else { generateDroppedJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, @@ -89,6 +90,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { * one cell, we use "rowspan" for the first row of an output op. */ private def generateNormalJobRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, @@ -106,7 +108,8 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { dropWhile(_.failureReason == None).take(1). // get the first info that contains failure flatMap(info => info.failureReason).headOption.getOrElse("") val formattedDuration = duration.map(d => SparkUIUtils.formatDuration(d)).getOrElse("-") - val detailUrl = s"${SparkUIUtils.prependBaseUri(parent.basePath)}/jobs/job?id=${sparkJob.jobId}" + val detailUrl = s"${SparkUIUtils.prependBaseUri( + request, parent.basePath)}/jobs/job/?id=${sparkJob.jobId}" // In the first row, output op id and its information needs to be shown. In other rows, these // cells will be taken up due to "rowspan". @@ -196,6 +199,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateOutputOpIdRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { val formattedOutputOpDuration = @@ -212,6 +216,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } else { val firstRow = generateJobRow( + request, outputOpData, description, formattedOutputOpDuration, @@ -221,6 +226,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { val tailRows = sparkJobs.tail.map { sparkJob => generateJobRow( + request, outputOpData, description, formattedOutputOpDuration, @@ -278,7 +284,9 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { /** * Generate the job table for the batch. */ - private def generateJobTable(batchUIData: BatchUIData): Seq[Node] = { + private def generateJobTable( + request: HttpServletRequest, + batchUIData: BatchUIData): Seq[Node] = { val outputOpIdToSparkJobIds = batchUIData.outputOpIdSparkJobIdPairs.groupBy(_.outputOpId). map { case (outputOpId, outputOpIdAndSparkJobIds) => // sort SparkJobIds for each OutputOpId @@ -301,7 +309,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { { outputOpWithJobs.map { case (outputOpData, sparkJobs) => - generateOutputOpIdRow(outputOpData, sparkJobs) + generateOutputOpIdRow(request, outputOpData, sparkJobs) } } @@ -364,9 +372,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {
      - val content = summary ++ generateJobTable(batchUIData) + val content = summary ++ generateJobTable(request, batchUIData) - SparkUIUtils.headerSparkPage(s"Details of batch at $formattedBatchTime", content, parent) + SparkUIUtils.headerSparkPage( + request, s"Details of batch at $formattedBatchTime", content, parent) } def generateInputMetadataTable(inputMetadatas: Seq[(Int, String)]): Seq[Node] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 3a176f64cdd60..4ce661bc1144e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -148,7 +148,7 @@ private[ui] class StreamingPage(parent: StreamingTab) /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { - val resources = generateLoadResources() + val resources = generateLoadResources(request) val basicInfo = generateBasicInfo() val content = resources ++ basicInfo ++ @@ -156,17 +156,17 @@ private[ui] class StreamingPage(parent: StreamingTab) generateStatTable() ++ generateBatchListTables() } - SparkUIUtils.headerSparkPage("Streaming Statistics", content, parent, Some(5000)) + SparkUIUtils.headerSparkPage(request, "Streaming Statistics", content, parent, Some(5000)) } /** * Generate html that will load css/js files for StreamingPage */ - private def generateLoadResources(): Seq[Node] = { + private def generateLoadResources(request: HttpServletRequest): Seq[Node] = { // scalastyle:off - - - + + + // scalastyle:on } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index 9d1b82a6341b1..25e71258b9369 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -49,7 +49,7 @@ private[spark] class StreamingTab(val ssc: StreamingContext) def detach() { getSparkUI(ssc).detachTab(this) - getSparkUI(ssc).removeStaticHandler("/static/streaming") + getSparkUI(ssc).detachHandler("/static/streaming") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index ab7c8558321c8..2e8599026ea1d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -222,7 +222,7 @@ private[streaming] class FileBasedWriteAheadLog( pastLogs += LogInfo(currentLogWriterStartTime, currentLogWriterStopTime, _) } currentLogWriterStartTime = currentTime - currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000) + currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000L) val newLogPath = new Path(logDirectory, timeToLogFile(currentLogWriterStartTime, currentLogWriterStopTime)) currentLogPath = Some(newLogPath.toString) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 4fa236bd39663..fd7e00b1de25f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -26,10 +26,12 @@ import scala.language.{implicitConversions, postfixOps} import scala.util.Random import org.apache.hadoop.conf.Configuration +import org.mockito.Matchers.any +import org.mockito.Mockito.{doThrow, reset, spy} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult @@ -115,6 +117,47 @@ class ReceivedBlockTrackerSuite tracker2.stop() } + test("block allocation to batch should not loose blocks from received queue") { + val tracker1 = spy(createTracker()) + tracker1.isWriteAheadLogEnabled should be (true) + tracker1.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + + // Add blocks + val blockInfos = generateBlockInfos() + blockInfos.map(tracker1.addBlock) + tracker1.getUnallocatedBlocks(streamId) shouldEqual blockInfos + + // Try to allocate the blocks to a batch and verify that it's failing + // The blocks should stay in the received queue when WAL write failing + doThrow(new RuntimeException("Not able to write BatchAllocationEvent")) + .when(tracker1).writeToLog(any(classOf[BatchAllocationEvent])) + val errMsg = intercept[RuntimeException] { + tracker1.allocateBlocksToBatch(1) + } + assert(errMsg.getMessage === "Not able to write BatchAllocationEvent") + tracker1.getUnallocatedBlocks(streamId) shouldEqual blockInfos + tracker1.getBlocksOfBatch(1) shouldEqual Map.empty + tracker1.getBlocksOfBatchAndStream(1, streamId) shouldEqual Seq.empty + + // Allocate the blocks to a batch and verify that all of them have been allocated + reset(tracker1) + tracker1.allocateBlocksToBatch(2) + tracker1.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + tracker1.hasUnallocatedReceivedBlocks should be (false) + tracker1.getBlocksOfBatch(2) shouldEqual Map(streamId -> blockInfos) + tracker1.getBlocksOfBatchAndStream(2, streamId) shouldEqual blockInfos + + tracker1.stop() + + // Recover from WAL to see the correctness + val tracker2 = createTracker(recoverFromWriteAheadLog = true) + tracker2.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + tracker2.hasUnallocatedReceivedBlocks should be (false) + tracker2.getBlocksOfBatch(2) shouldEqual Map(streamId -> blockInfos) + tracker2.getBlocksOfBatchAndStream(2, streamId) shouldEqual blockInfos + tracker2.stop() + } + test("recovery and cleanup with write ahead logs") { val manualClock = new ManualClock // Set the time increment level to twice the rotation interval so that every increment creates @@ -312,7 +355,7 @@ class ReceivedBlockTrackerSuite recoverFromWriteAheadLog: Boolean = false, clock: Clock = new SystemClock): ReceivedBlockTracker = { val cpDirOption = if (setCheckpointDir) Some(checkpointDirectory.toString) else None - val tracker = new ReceivedBlockTracker( + var tracker = new ReceivedBlockTracker( conf, hadoopConf, Seq(streamId), clock, recoverFromWriteAheadLog, cpDirOption) allReceivedBlockTrackers += tracker tracker