From f0ee986c1074bddc2b67a7b270136a8040675c8d Mon Sep 17 00:00:00 2001 From: Michele Rastelli Date: Wed, 31 May 2023 14:02:14 +0200 Subject: [PATCH] [DE-598] Spark 3.4 (#45) * Spark 3.4 * test fixes * upd sonar spark-version * test fixes --- .github/workflows/maven-deploy.yml | 20 + .github/workflows/maven-release.yml | 8 + .github/workflows/test.yml | 35 +- arangodb-spark-datasource-3.4/pom.xml | 79 +++ ...db.commons.mapping.ArangoGeneratorProvider | 1 + ...ngodb.commons.mapping.ArangoParserProvider | 1 + ...pache.spark.sql.sources.DataSourceRegister | 1 + .../com/arangodb/spark/DefaultSource.scala | 40 ++ .../sql/arangodb/datasource/ArangoTable.scala | 37 ++ .../mapping/ArangoGeneratorImpl.scala | 46 ++ .../datasource/mapping/ArangoParserImpl.scala | 47 ++ .../mapping/json/CreateJacksonParser.scala | 95 +++ .../datasource/mapping/json/JSONOptions.scala | 268 ++++++++ .../mapping/json/JacksonGenerator.scala | 327 ++++++++++ .../mapping/json/JacksonParser.scala | 589 ++++++++++++++++++ .../mapping/json/JacksonUtils.scala | 64 ++ .../datasource/mapping/json/JsonFilters.scala | 158 +++++ .../mapping/json/JsonInferSchema.scala | 413 ++++++++++++ .../arangodb/datasource/mapping/package.scala | 14 + .../reader/ArangoCollectionPartition.scala | 15 + .../ArangoCollectionPartitionReader.scala | 65 ++ .../reader/ArangoPartitionReaderFactory.scala | 13 + .../datasource/reader/ArangoQueryReader.scala | 63 ++ .../datasource/reader/ArangoScan.scala | 28 + .../datasource/reader/ArangoScanBuilder.scala | 66 ++ .../datasource/writer/ArangoBatchWriter.scala | 30 + .../datasource/writer/ArangoDataWriter.scala | 136 ++++ .../writer/ArangoDataWriterFactory.scala | 12 + .../writer/ArangoWriterBuilder.scala | 93 +++ bin/clean.sh | 2 + bin/test.sh | 6 + demo/README.md | 2 +- demo/docker/start_spark_3.2.sh | 8 +- demo/pom.xml | 15 +- docker/start_spark_2.4.sh | 7 - docker/start_spark_3.1.sh | 7 - docker/stop.sh | 10 - .../datasource/DeserializationCastTest.scala | 3 + .../arangodb/datasource/write/AbortTest.scala | 23 +- .../datasource/write/OverwriteModeTest.scala | 3 +- pom.xml | 9 + 41 files changed, 2817 insertions(+), 42 deletions(-) create mode 100644 arangodb-spark-datasource-3.4/pom.xml create mode 100644 arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoGeneratorProvider create mode 100644 arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider create mode 100644 arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/com/arangodb/spark/DefaultSource.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/ArangoTable.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoGeneratorImpl.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/CreateJacksonParser.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JSONOptions.scala create mode 100644 
arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonGenerator.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonParser.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonUtils.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JsonFilters.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JsonInferSchema.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/package.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartition.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoPartitionReaderFactory.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScan.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScanBuilder.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoBatchWriter.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriter.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriterFactory.scala create mode 100644 arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoWriterBuilder.scala delete mode 100755 docker/start_spark_2.4.sh delete mode 100755 docker/start_spark_3.1.sh delete mode 100755 docker/stop.sh diff --git a/.github/workflows/maven-deploy.yml b/.github/workflows/maven-deploy.yml index 76c6a040..4afab33b 100644 --- a/.github/workflows/maven-deploy.yml +++ b/.github/workflows/maven-deploy.yml @@ -12,6 +12,26 @@ jobs: strategy: fail-fast: false + matrix: + include: + - scala-version: 2.11 + spark-version: 2.4 + - scala-version: 2.12 + spark-version: 2.4 + - scala-version: 2.12 + spark-version: 3.1 + - scala-version: 2.12 + spark-version: 3.2 + - scala-version: 2.13 + spark-version: 3.2 + - scala-version: 2.12 + spark-version: 3.3 + - scala-version: 2.13 + spark-version: 3.3 + - scala-version: 2.12 + spark-version: 3.4 + - scala-version: 2.13 + spark-version: 3.4 steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/maven-release.yml b/.github/workflows/maven-release.yml index 60ce229d..094fea9a 100644 --- a/.github/workflows/maven-release.yml +++ b/.github/workflows/maven-release.yml @@ -24,6 +24,14 @@ jobs: spark-version: 3.2 - scala-version: 2.13 spark-version: 3.2 + - scala-version: 2.12 + spark-version: 3.3 + - scala-version: 2.13 + spark-version: 3.3 + - scala-version: 2.12 + spark-version: 3.4 + - scala-version: 2.13 + spark-version: 3.4 steps: - uses: actions/checkout@v2 diff --git 
a/.github/workflows/test.yml b/.github/workflows/test.yml index 4462c105..e0c3d926 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,6 +38,8 @@ jobs: - 2.4 - 3.1 - 3.2 + - 3.3 + - 3.4 topology: - single - cluster @@ -53,12 +55,18 @@ jobs: spark-version: 3.1 - scala-version: 2.11 spark-version: 3.2 + - scala-version: 2.11 + spark-version: 3.3 + - scala-version: 2.11 + spark-version: 3.4 - scala-version: 2.11 java-version: 11 - scala-version: 2.13 spark-version: 2.4 - scala-version: 2.13 spark-version: 3.1 + - docker-img: docker.io/arangodb/arangodb:3.9.10 + java-version: 8 - docker-img: docker.io/arangodb/arangodb:3.10.6 java-version: 8 - docker-img: docker.io/arangodb/arangodb:3.11.0 @@ -96,6 +104,8 @@ jobs: - 2.4 - 3.1 - 3.2 + - 3.3 + - 3.4 topology: - cluster java-version: @@ -107,6 +117,10 @@ jobs: spark-version: 3.1 - scala-version: 2.11 spark-version: 3.2 + - scala-version: 2.11 + spark-version: 3.3 + - scala-version: 2.11 + spark-version: 3.4 - scala-version: 2.13 spark-version: 2.4 - scala-version: 2.13 @@ -140,10 +154,15 @@ jobs: matrix: python-version: [3.9] scala-version: [2.12] - spark-version: [3.1, 3.2] + spark-version: [3.1, 3.2, 3.3, 3.4] topology: [single, cluster] java-version: [8, 11] docker-img: ["docker.io/arangodb/arangodb:3.11.0"] + exclude: + - topology: cluster + java-version: 8 + - topology: single + java-version: 11 steps: - uses: actions/checkout@v2 @@ -191,6 +210,8 @@ jobs: - 2.4 - 3.1 - 3.2 + - 3.3 + - 3.4 topology: - single java-version: @@ -203,6 +224,10 @@ jobs: spark-version: 3.1 - scala-version: 2.11 spark-version: 3.2 + - scala-version: 2.11 + spark-version: 3.3 + - scala-version: 2.11 + spark-version: 3.4 - scala-version: 2.13 spark-version: 2.4 - scala-version: 2.13 @@ -301,6 +326,12 @@ jobs: - spark-version: 3.3 scala-version: 2.13 spark-full-version: 3.3.2 + - spark-version: 3.4 + scala-version: 2.12 + spark-full-version: 3.4.0 + - spark-version: 3.4 + scala-version: 2.13 + spark-full-version: 3.4.0 steps: - uses: actions/checkout@v2 @@ -331,7 +362,7 @@ jobs: scala-version: - 2.12 spark-version: - - 3.2 + - 3.4 topology: - single java-version: diff --git a/arangodb-spark-datasource-3.4/pom.xml b/arangodb-spark-datasource-3.4/pom.xml new file mode 100644 index 00000000..4cbc5c6a --- /dev/null +++ b/arangodb-spark-datasource-3.4/pom.xml @@ -0,0 +1,79 @@ + + + + arangodb-spark-datasource + com.arangodb + 1.4.3 + + 4.0.0 + + arangodb-spark-datasource-3.4_${scala.compat.version} + + arangodb-spark-datasource-3.4 + ArangoDB Datasource for Apache Spark 3.4 + https://github.com/arangodb/arangodb-spark-datasource + + + + Michele Rastelli + https://github.com/rashtao + + + + + https://github.com/arangodb/arangodb-spark-datasource + + + + false + ../integration-tests/target/site/jacoco-aggregate/jacoco.xml + src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/* + src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/* + false + + + + + com.arangodb + arangodb-spark-commons-${spark.compat.version}_${scala.compat.version} + ${project.version} + + + org.apache.httpcomponents + httpclient + 4.5.13 + + + + + + + maven-assembly-plugin + + + jar-with-dependencies + + + + + package + + single + + + + + + org.sonatype.plugins + nexus-staging-maven-plugin + true + + false + + + + + + \ No newline at end of file diff --git a/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoGeneratorProvider 
b/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoGeneratorProvider new file mode 100644 index 00000000..477374e3 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoGeneratorProvider @@ -0,0 +1 @@ +org.apache.spark.sql.arangodb.datasource.mapping.ArangoGeneratorProviderImpl \ No newline at end of file diff --git a/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider b/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider new file mode 100644 index 00000000..3e6a6b92 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider @@ -0,0 +1 @@ +org.apache.spark.sql.arangodb.datasource.mapping.ArangoParserProviderImpl diff --git a/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 00000000..5a634481 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +com.arangodb.spark.DefaultSource diff --git a/arangodb-spark-datasource-3.4/src/main/scala/com/arangodb/spark/DefaultSource.scala b/arangodb-spark-datasource-3.4/src/main/scala/com/arangodb/spark/DefaultSource.scala new file mode 100644 index 00000000..38c0925c --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/com/arangodb/spark/DefaultSource.scala @@ -0,0 +1,40 @@ +package com.arangodb.spark + +import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf} +import org.apache.spark.sql.arangodb.datasource.ArangoTable +import org.apache.spark.sql.connector.catalog.{Table, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util + +class DefaultSource extends TableProvider with DataSourceRegister { + + private def extractOptions(options: util.Map[String, String]): ArangoDBConf = { + val opts: ArangoDBConf = ArangoDBConf(options) + if (opts.driverOptions.acquireHostList) { + val hosts = ArangoClient.acquireHostList(opts) + opts.updated(ArangoDBConf.ENDPOINTS, hosts.mkString(",")) + } else { + opts + } + } + + override def inferSchema(options: CaseInsensitiveStringMap): StructType = getTable(options).schema() + + private def getTable(options: CaseInsensitiveStringMap): Table = + getTable(None, options.asCaseSensitiveMap()) // scalastyle:ignore null + + override def getTable(schema: StructType, partitioning: Array[Transform], properties: util.Map[String, String]): Table = + getTable(Option(schema), properties) + + override def supportsExternalMetadata(): Boolean = true + + override def shortName(): String = "arangodb" + + private def getTable(schema: Option[StructType], properties: util.Map[String, String]) = + new ArangoTable(schema, extractOptions(properties)) + +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/ArangoTable.scala 
b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/ArangoTable.scala new file mode 100644 index 00000000..e8f4d5a8 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/ArangoTable.scala @@ -0,0 +1,37 @@ +package org.apache.spark.sql.arangodb.datasource + +import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ArangoUtils} +import org.apache.spark.sql.arangodb.datasource.reader.ArangoScanBuilder +import org.apache.spark.sql.arangodb.datasource.writer.ArangoWriterBuilder +import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util +import scala.collection.JavaConverters.setAsJavaSetConverter + +class ArangoTable(private var schemaOpt: Option[StructType], options: ArangoDBConf) extends Table with SupportsRead with SupportsWrite { + private lazy val tableSchema = schemaOpt.getOrElse(ArangoUtils.inferSchema(options)) + + override def name(): String = this.getClass.toString + + override def schema(): StructType = tableSchema + + override def capabilities(): util.Set[TableCapability] = Set( + TableCapability.BATCH_READ, + TableCapability.BATCH_WRITE, + // TableCapability.STREAMING_WRITE, + TableCapability.ACCEPT_ANY_SCHEMA, + TableCapability.TRUNCATE + // TableCapability.OVERWRITE_BY_FILTER, + // TableCapability.OVERWRITE_DYNAMIC, + ).asJava + + override def newScanBuilder(scanOptions: CaseInsensitiveStringMap): ScanBuilder = + new ArangoScanBuilder(options.updated(ArangoDBConf(scanOptions)), schema()) + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = + new ArangoWriterBuilder(info.schema(), options.updated(ArangoDBConf(info.options()))) +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoGeneratorImpl.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoGeneratorImpl.scala new file mode 100644 index 00000000..b4c07882 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoGeneratorImpl.scala @@ -0,0 +1,46 @@ +package org.apache.spark.sql.arangodb.datasource.mapping + +import com.arangodb.jackson.dataformat.velocypack.VPackFactoryBuilder +import com.fasterxml.jackson.core.JsonFactoryBuilder +import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType} +import org.apache.spark.sql.arangodb.commons.mapping.{ArangoGenerator, ArangoGeneratorProvider} +import org.apache.spark.sql.arangodb.datasource.mapping.json.{JSONOptions, JacksonGenerator} +import org.apache.spark.sql.types.{DataType, StructType} + +import java.io.OutputStream + +abstract sealed class ArangoGeneratorImpl( + schema: DataType, + writer: OutputStream, + options: JSONOptions) + extends JacksonGenerator( + schema, + options.buildJsonFactory().createGenerator(writer), + options) with ArangoGenerator + +class ArangoGeneratorProviderImpl extends ArangoGeneratorProvider { + override def of( + contentType: ContentType, + schema: StructType, + outputStream: OutputStream, + conf: ArangoDBConf + ): ArangoGeneratorImpl = contentType match { + case ContentType.JSON => new JsonArangoGenerator(schema, outputStream, conf) + case ContentType.VPACK 
=> new VPackArangoGenerator(schema, outputStream, conf) + case _ => throw new IllegalArgumentException + } +} + +class JsonArangoGenerator(schema: StructType, outputStream: OutputStream, conf: ArangoDBConf) + extends ArangoGeneratorImpl( + schema, + outputStream, + createOptions(new JsonFactoryBuilder().build(), conf) + ) + +class VPackArangoGenerator(schema: StructType, outputStream: OutputStream, conf: ArangoDBConf) + extends ArangoGeneratorImpl( + schema, + outputStream, + createOptions(new VPackFactoryBuilder().build(), conf) + ) diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala new file mode 100644 index 00000000..dad564ce --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/ArangoParserImpl.scala @@ -0,0 +1,47 @@ +package org.apache.spark.sql.arangodb.datasource.mapping + +import com.arangodb.jackson.dataformat.velocypack.VPackFactoryBuilder +import com.fasterxml.jackson.core.json.JsonReadFeature +import com.fasterxml.jackson.core.{JsonFactory, JsonFactoryBuilder} +import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ContentType} +import org.apache.spark.sql.arangodb.commons.mapping.{ArangoParser, ArangoParserProvider, MappingUtils} +import org.apache.spark.sql.arangodb.datasource.mapping.json.{JSONOptions, JacksonParser} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.DataType +import org.apache.spark.unsafe.types.UTF8String + +abstract sealed class ArangoParserImpl( + schema: DataType, + options: JSONOptions, + recordLiteral: Array[Byte] => UTF8String) + extends JacksonParser(schema, options) with ArangoParser { + override def parse(data: Array[Byte]): Iterable[InternalRow] = super.parse( + data, + (jsonFactory: JsonFactory, record: Array[Byte]) => jsonFactory.createParser(record), + recordLiteral + ) +} + +class ArangoParserProviderImpl extends ArangoParserProvider { + override def of(contentType: ContentType, schema: DataType, conf: ArangoDBConf): ArangoParserImpl = contentType match { + case ContentType.JSON => new JsonArangoParser(schema, conf) + case ContentType.VPACK => new VPackArangoParser(schema, conf) + case _ => throw new IllegalArgumentException + } +} + +class JsonArangoParser(schema: DataType, conf: ArangoDBConf) + extends ArangoParserImpl( + schema, + createOptions(new JsonFactoryBuilder() + .configure(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS, true) + .build(), conf), + (bytes: Array[Byte]) => UTF8String.fromBytes(bytes) + ) + +class VPackArangoParser(schema: DataType, conf: ArangoDBConf) + extends ArangoParserImpl( + schema, + createOptions(new VPackFactoryBuilder().build(), conf), + (bytes: Array[Byte]) => UTF8String.fromString(MappingUtils.vpackToJson(bytes)) + ) diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/CreateJacksonParser.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/CreateJacksonParser.scala new file mode 100644 index 00000000..0fa095f1 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/CreateJacksonParser.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off + +package org.apache.spark.sql.arangodb.datasource.mapping.json + +import com.fasterxml.jackson.core.{JsonFactory, JsonParser} +import org.apache.hadoop.io.Text +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.unsafe.types.UTF8String +import sun.nio.cs.StreamDecoder + +import java.io.{ByteArrayInputStream, InputStream, InputStreamReader} +import java.nio.channels.Channels +import java.nio.charset.{Charset, StandardCharsets} + +private[sql] object CreateJacksonParser extends Serializable { + def string(jsonFactory: JsonFactory, record: String): JsonParser = { + jsonFactory.createParser(record) + } + + def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = { + val bb = record.getByteBuffer + assert(bb.hasArray) + + val bain = new ByteArrayInputStream( + bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + + jsonFactory.createParser(new InputStreamReader(bain, StandardCharsets.UTF_8)) + } + + def text(jsonFactory: JsonFactory, record: Text): JsonParser = { + jsonFactory.createParser(record.getBytes, 0, record.getLength) + } + + // Jackson parsers can be ranked according to their performance: + // 1. Array based with actual encoding UTF-8 in the array. This is the fastest parser + // but it doesn't allow to set encoding explicitly. Actual encoding is detected automatically + // by checking leading bytes of the array. + // 2. InputStream based with actual encoding UTF-8 in the stream. Encoding is detected + // automatically by analyzing first bytes of the input stream. + // 3. Reader based parser. This is the slowest parser used here but it allows to create + // a reader with specific encoding. + // The method creates a reader for an array with given encoding and sets size of internal + // decoding buffer according to size of input array. 
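The three strategies ranked in the comment above correspond to the standard Jackson `JsonFactory.createParser` overloads. A minimal sketch for comparison (the payload and names are illustrative, not part of the patch):

    import com.fasterxml.jackson.core.JsonFactory
    import java.io.{ByteArrayInputStream, InputStreamReader}
    import java.nio.charset.StandardCharsets

    // Sketch of the three parser-creation strategies ranked above.
    object ParserCreationSketch {
      val factory = new JsonFactory()
      val bytes: Array[Byte] = """{"a":1}""".getBytes(StandardCharsets.UTF_8)

      // 1. Array based: fastest; encoding detected from the leading bytes.
      val fromArray = factory.createParser(bytes, 0, bytes.length)
      // 2. InputStream based: encoding detected from the first bytes of the stream.
      val fromStream = factory.createParser(new ByteArrayInputStream(bytes))
      // 3. Reader based: slowest, but allows an explicit encoding.
      val fromReader = factory.createParser(
        new InputStreamReader(new ByteArrayInputStream(bytes), StandardCharsets.UTF_8))
    }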
+ private def getStreamDecoder(enc: String, in: Array[Byte], length: Int): StreamDecoder = { + val bais = new ByteArrayInputStream(in, 0, length) + val byteChannel = Channels.newChannel(bais) + val decodingBufferSize = Math.min(length, 8192) + val decoder = Charset.forName(enc).newDecoder() + + StreamDecoder.forDecoder(byteChannel, decoder, decodingBufferSize) + } + + def text(enc: String, jsonFactory: JsonFactory, record: Text): JsonParser = { + val sd = getStreamDecoder(enc, record.getBytes, record.getLength) + jsonFactory.createParser(sd) + } + + def inputStream(jsonFactory: JsonFactory, is: InputStream): JsonParser = { + jsonFactory.createParser(is) + } + + def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = { + jsonFactory.createParser(new InputStreamReader(is, enc)) + } + + def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = { + val ba = row.getBinary(0) + + jsonFactory.createParser(ba, 0, ba.length) + } + + def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = { + val binary = row.getBinary(0) + val sd = getStreamDecoder(enc, binary, binary.length) + + jsonFactory.createParser(sd) + } +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JSONOptions.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JSONOptions.scala new file mode 100644 index 00000000..a9034cbe --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JSONOptions.scala @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off + +package org.apache.spark.sql.arangodb.datasource.mapping.json + +import com.fasterxml.jackson.core.json.JsonReadFeature +import com.fasterxml.jackson.core.{JsonFactory, JsonFactoryBuilder} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy + +import java.nio.charset.{Charset, StandardCharsets} +import java.time.ZoneId +import java.util.Locale + +/** + * Options for parsing JSON data into Spark SQL rows. + * + * Most of these map directly to Jackson's internal options, specified in [[JsonReadFeature]]. 
+ */ +private[sql] class JSONOptions( + @transient val parameters: CaseInsensitiveMap[String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String) + extends FileSourceOptions(parameters) with Logging { + + import JSONOptions._ + + def this( + parameters: Map[String, String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String = "") = { + this( + CaseInsensitiveMap(parameters), + defaultTimeZoneId, + defaultColumnNameOfCorruptRecord) + } + + val samplingRatio = + parameters.get(SAMPLING_RATIO).map(_.toDouble).getOrElse(1.0) + val primitivesAsString = + parameters.get(PRIMITIVES_AS_STRING).map(_.toBoolean).getOrElse(false) + val prefersDecimal = + parameters.get(PREFERS_DECIMAL).map(_.toBoolean).getOrElse(false) + val allowComments = + parameters.get(ALLOW_COMMENTS).map(_.toBoolean).getOrElse(false) + val allowUnquotedFieldNames = + parameters.get(ALLOW_UNQUOTED_FIELD_NAMES).map(_.toBoolean).getOrElse(false) + val allowSingleQuotes = + parameters.get(ALLOW_SINGLE_QUOTES).map(_.toBoolean).getOrElse(true) + val allowNumericLeadingZeros = + parameters.get(ALLOW_NUMERIC_LEADING_ZEROS).map(_.toBoolean).getOrElse(false) + val allowNonNumericNumbers = + parameters.get(ALLOW_NON_NUMERIC_NUMBERS).map(_.toBoolean).getOrElse(true) + val allowBackslashEscapingAnyCharacter = + parameters.get(ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER).map(_.toBoolean).getOrElse(false) + private val allowUnquotedControlChars = + parameters.get(ALLOW_UNQUOTED_CONTROL_CHARS).map(_.toBoolean).getOrElse(false) + val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName) + val parseMode: ParseMode = + parameters.get(MODE).map(ParseMode.fromString).getOrElse(PermissiveMode) + val columnNameOfCorruptRecord = + parameters.getOrElse(COLUMN_NAME_OF_CORRUPTED_RECORD, defaultColumnNameOfCorruptRecord) + + // Whether to ignore column of all null values or empty array/struct during schema inference + val dropFieldIfAllNull = parameters.get(DROP_FIELD_IF_ALL_NULL).map(_.toBoolean).getOrElse(false) + + // Whether to ignore null fields during json generating + val ignoreNullFields = parameters.get(IGNORE_NULL_FIELDS).map(_.toBoolean) + .getOrElse(SQLConf.get.jsonGeneratorIgnoreNullFields) + + // If this is true, when writing NULL values to columns of JSON tables with explicit DEFAULT + // values, never skip writing the NULL values to storage, overriding 'ignoreNullFields' above. + // This can be useful to enforce that inserted NULL values are present in storage to differentiate + // from missing data. 
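Combined with ignoreNullFields above, the rule that JacksonGenerator.writeFields applies later in this patch reduces to a single predicate; a simplified sketch (parameter names are illustrative — `hasExplicitDefault` stands for field.getExistenceDefaultValue().isDefined, and the connector additionally never emits a null `_key` field):

    // Sketch of the null-writing rule used in JacksonGenerator.writeFields.
    def shouldWriteNullField(ignoreNullFields: Boolean,
                             writeNullIfWithDefaultValue: Boolean,
                             hasExplicitDefault: Boolean,
                             isKeyField: Boolean): Boolean =
      (!ignoreNullFields || (writeNullIfWithDefaultValue && hasExplicitDefault)) && !isKeyField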
+ val writeNullIfWithDefaultValue = SQLConf.get.jsonWriteNullIfWithDefaultValue + + // A language tag in IETF BCP 47 format + val locale: Locale = parameters.get(LOCALE).map(Locale.forLanguageTag).getOrElse(Locale.US) + + val zoneId: ZoneId = DateTimeUtils.getZoneId( + parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) + + val dateFormatInRead: Option[String] = parameters.get(DATE_FORMAT) + val dateFormatInWrite: String = parameters.getOrElse(DATE_FORMAT, DateFormatter.defaultPattern) + + val timestampFormatInRead: Option[String] = + if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) { + Some(parameters.getOrElse(TIMESTAMP_FORMAT, + s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX")) + } else { + parameters.get(TIMESTAMP_FORMAT) + } + val timestampFormatInWrite: String = parameters.getOrElse(TIMESTAMP_FORMAT, + if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) { + s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX" + } else { + s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS][XXX]" + }) + + val timestampNTZFormatInRead: Option[String] = parameters.get(TIMESTAMP_NTZ_FORMAT) + val timestampNTZFormatInWrite: String = + parameters.getOrElse(TIMESTAMP_NTZ_FORMAT, s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS]") + + // SPARK-39731: Enables the backward compatible parsing behavior. + // Generally, this config should be set to false to avoid producing potentially incorrect results + // which is the current default (see JacksonParser). + // + // If enabled and the date cannot be parsed, we will fall back to `DateTimeUtils.stringToDate`. + // If enabled and the timestamp cannot be parsed, `DateTimeUtils.stringToTimestamp` will be used. + // Otherwise, depending on the parser policy and a custom pattern, an exception may be thrown and + // the value will be parsed as null. + val enableDateTimeParsingFallback: Option[Boolean] = + parameters.get(ENABLE_DATETIME_PARSING_FALLBACK).map(_.toBoolean) + + val multiLine = parameters.get(MULTI_LINE).map(_.toBoolean).getOrElse(false) + + /** + * A string between two consecutive JSON records. + */ + val lineSeparator: Option[String] = parameters.get(LINE_SEP).map { sep => + require(sep.nonEmpty, "'lineSep' cannot be an empty string.") + sep + } + + protected def checkedEncoding(enc: String): String = enc + + /** + * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE. + * If the encoding is not specified (None) in read, it will be detected automatically + * when the multiLine option is set to `true`. If encoding is not specified in write, + * UTF-8 is used by default. + */ + val encoding: Option[String] = parameters.get(ENCODING) + .orElse(parameters.get(CHARSET)).map(checkedEncoding) + + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(encoding.getOrElse(StandardCharsets.UTF_8.name())) + } + val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n") + + /** + * Generating JSON strings in pretty representation if the parameter is enabled. + */ + val pretty: Boolean = parameters.get(PRETTY).map(_.toBoolean).getOrElse(false) + + /** + * Enables inferring of TimestampType and TimestampNTZType from strings matched to the + * corresponding timestamp pattern defined by the timestampFormat and timestampNTZFormat options + * respectively. 
+ */ + val inferTimestamp: Boolean = parameters.get(INFER_TIMESTAMP).map(_.toBoolean).getOrElse(false) + + /** + * Generating \u0000 style codepoints for non-ASCII characters if the parameter is enabled. + */ + val writeNonAsciiCharacterAsCodePoint: Boolean = + parameters.get(WRITE_NON_ASCII_CHARACTER_AS_CODEPOINT).map(_.toBoolean).getOrElse(false) + + /** Build a Jackson [[JsonFactory]] using JSON options. */ + def buildJsonFactory(): JsonFactory = { + new JsonFactoryBuilder() + .configure(JsonReadFeature.ALLOW_JAVA_COMMENTS, allowComments) + .configure(JsonReadFeature.ALLOW_UNQUOTED_FIELD_NAMES, allowUnquotedFieldNames) + .configure(JsonReadFeature.ALLOW_SINGLE_QUOTES, allowSingleQuotes) + .configure(JsonReadFeature.ALLOW_LEADING_ZEROS_FOR_NUMBERS, allowNumericLeadingZeros) + .configure(JsonReadFeature.ALLOW_NON_NUMERIC_NUMBERS, allowNonNumericNumbers) + .configure( + JsonReadFeature.ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER, + allowBackslashEscapingAnyCharacter) + .configure(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS, allowUnquotedControlChars) + .build() + } +} + +private[sql] class JSONOptionsInRead( + @transient override val parameters: CaseInsensitiveMap[String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String) + extends JSONOptions(parameters, defaultTimeZoneId, defaultColumnNameOfCorruptRecord) { + + def this( + parameters: Map[String, String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String = "") = { + this( + CaseInsensitiveMap(parameters), + defaultTimeZoneId, + defaultColumnNameOfCorruptRecord) + } + + protected override def checkedEncoding(enc: String): String = { + val isDenied = JSONOptionsInRead.denyList.contains(Charset.forName(enc)) + require(multiLine || !isDenied, + s"""The $enc encoding must not be included in the denyList when multiLine is disabled: + |denylist: ${JSONOptionsInRead.denyList.mkString(", ")}""".stripMargin) + + val isLineSepRequired = + multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty + require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") + + enc + } +} + +private[sql] object JSONOptionsInRead { + // The following encodings are not supported in per-line mode (multiline is false) + // because they cause some problems in reading files with BOM which is supposed to + // present in the files with such encodings. After splitting input files by lines, + // only the first lines will have the BOM which leads to impossibility for reading + // the rest lines. Besides of that, the lineSep option must have the BOM in such + // encodings which can never present between lines. 
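In practice this means a denied charset is rejected at construction time in per-line mode, while enabling multiLine lifts the restriction. A sketch, assuming it lives in this same package (the classes are private[sql]) and using illustrative option values:

    import scala.util.Try

    // Sketch: exercising JSONOptionsInRead.checkedEncoding.
    object EncodingCheckSketch {
      // UTF-16 is denied in per-line mode because of the BOM issue described above...
      val rejected = Try(new JSONOptionsInRead(
        Map("encoding" -> "UTF-16"), defaultTimeZoneId = "UTC"))
      // ...but accepted once multiLine is enabled.
      val accepted = Try(new JSONOptionsInRead(
        Map("encoding" -> "UTF-16", "multiLine" -> "true"), defaultTimeZoneId = "UTC"))
      // Expected: rejected.isFailure, accepted.isSuccess
    }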
+ val denyList = Seq( + Charset.forName("UTF-16"), + Charset.forName("UTF-32") + ) +} + +object JSONOptions extends DataSourceOptions { + val SAMPLING_RATIO = newOption("samplingRatio") + val PRIMITIVES_AS_STRING = newOption("primitivesAsString") + val PREFERS_DECIMAL = newOption("prefersDecimal") + val ALLOW_COMMENTS = newOption("allowComments") + val ALLOW_UNQUOTED_FIELD_NAMES = newOption("allowUnquotedFieldNames") + val ALLOW_SINGLE_QUOTES = newOption("allowSingleQuotes") + val ALLOW_NUMERIC_LEADING_ZEROS = newOption("allowNumericLeadingZeros") + val ALLOW_NON_NUMERIC_NUMBERS = newOption("allowNonNumericNumbers") + val ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER = newOption("allowBackslashEscapingAnyCharacter") + val ALLOW_UNQUOTED_CONTROL_CHARS = newOption("allowUnquotedControlChars") + val COMPRESSION = newOption("compression") + val MODE = newOption("mode") + val DROP_FIELD_IF_ALL_NULL = newOption("dropFieldIfAllNull") + val IGNORE_NULL_FIELDS = newOption("ignoreNullFields") + val LOCALE = newOption("locale") + val DATE_FORMAT = newOption("dateFormat") + val TIMESTAMP_FORMAT = newOption("timestampFormat") + val TIMESTAMP_NTZ_FORMAT = newOption("timestampNTZFormat") + val ENABLE_DATETIME_PARSING_FALLBACK = newOption("enableDateTimeParsingFallback") + val MULTI_LINE = newOption("multiLine") + val LINE_SEP = newOption("lineSep") + val PRETTY = newOption("pretty") + val INFER_TIMESTAMP = newOption("inferTimestamp") + val COLUMN_NAME_OF_CORRUPTED_RECORD = newOption("columnNameOfCorruptRecord") + val TIME_ZONE = newOption("timeZone") + val WRITE_NON_ASCII_CHARACTER_AS_CODEPOINT = newOption("writeNonAsciiCharacterAsCodePoint") + // Options with alternative + val ENCODING = "encoding" + val CHARSET = "charset" + newOption(ENCODING, CHARSET) +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonGenerator.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonGenerator.scala new file mode 100644 index 00000000..c28f788b --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonGenerator.scala @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off + +package org.apache.spark.sql.arangodb.datasource.mapping.json + +import com.fasterxml.jackson.core._ +import com.fasterxml.jackson.core.util.DefaultPrettyPrinter +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types._ + +import java.io.Writer + +/** + * `JackGenerator` can only be initialized with a `StructType`, a `MapType` or an `ArrayType`. + * Once it is initialized with `StructType`, it can be used to write out a struct or an array of + * struct. Once it is initialized with `MapType`, it can be used to write out a map or an array + * of map. An exception will be thrown if trying to write out a struct if it is initialized with + * a `MapType`, and vice verse. + */ +private[sql] class JacksonGenerator( + dataType: DataType, + generator: JsonGenerator, + options: JSONOptions) { + + def this(dataType: DataType, + writer: Writer, + options: JSONOptions) { + this( + dataType, + options.buildJsonFactory().createGenerator(writer).setRootValueSeparator(null), + options) + } + + // A `ValueWriter` is responsible for writing a field of an `InternalRow` to appropriate + // JSON data. Here we are using `SpecializedGetters` rather than `InternalRow` so that + // we can directly access data in `ArrayData` without the help of `SpecificMutableRow`. + private type ValueWriter = (SpecializedGetters, Int) => Unit + + // `JackGenerator` can only be initialized with a `StructType`, a `MapType` or a `ArrayType`. + require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType] + || dataType.isInstanceOf[ArrayType], + s"JacksonGenerator only supports to be initialized with a ${StructType.simpleString}, " + + s"${MapType.simpleString} or ${ArrayType.simpleString} but got ${dataType.catalogString}") + + // `ValueWriter`s for all fields of the schema + private lazy val rootFieldWriters: Array[ValueWriter] = dataType match { + case st: StructType => st.map(_.dataType).map(makeWriter).toArray + case _ => throw QueryExecutionErrors.initialTypeNotTargetDataTypeError( + dataType, StructType.simpleString) + } + + // `ValueWriter` for array data storing rows of the schema. 
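As a usage sketch of the generator defined above: initialized with a StructType it can emit a single object via write(InternalRow) or an array of objects via write(ArrayData), which is what the array element writer declared next is for. The class is package-private, so this only compiles under an org.apache.spark.sql package; the schema, values, and expected output are illustrative:

    import java.io.StringWriter
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.util.GenericArrayData
    import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
    import org.apache.spark.unsafe.types.UTF8String

    // Sketch: writing a row and an array of rows with JacksonGenerator.
    object GeneratorUsageSketch {
      val schema = new StructType().add("name", StringType).add("age", IntegerType)
      val out = new StringWriter()
      val gen = new JacksonGenerator(schema, out, new JSONOptions(Map.empty[String, String], "UTC"))

      val row = InternalRow(UTF8String.fromString("alice"), 42)
      gen.write(row)                                        // roughly {"name":"alice","age":42}
      gen.write(new GenericArrayData(Array[Any](row, row))) // array of objects, via the array element writer
      gen.flush()
    }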
+ private lazy val arrElementWriter: ValueWriter = dataType match { + case at: ArrayType => makeWriter(at.elementType) + case _: StructType | _: MapType => makeWriter(dataType) + case _ => throw QueryExecutionErrors.initialTypeNotTargetDataTypesError(dataType) + } + + private lazy val mapElementWriter: ValueWriter = dataType match { + case mt: MapType => makeWriter(mt.valueType) + case _ => throw QueryExecutionErrors.initialTypeNotTargetDataTypeError( + dataType, MapType.simpleString) + } + + private val gen = { + if (options.pretty) generator.setPrettyPrinter(new DefaultPrettyPrinter("")) else generator + } + + private val lineSeparator: String = options.lineSeparatorInWrite + + private val timestampFormatter = TimestampFormatter( + options.timestampFormatInWrite, + options.zoneId, + options.locale, + legacyFormat = FAST_DATE_FORMAT, + isParsing = false) + private val timestampNTZFormatter = TimestampFormatter( + options.timestampNTZFormatInWrite, + options.zoneId, + legacyFormat = FAST_DATE_FORMAT, + isParsing = false, + forTimestampNTZ = true) + private val dateFormatter = DateFormatter( + options.dateFormatInWrite, + options.locale, + legacyFormat = FAST_DATE_FORMAT, + isParsing = false) + + private def makeWriter(dataType: DataType): ValueWriter = dataType match { + case NullType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNull() + + case BooleanType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeBoolean(row.getBoolean(ordinal)) + + case ByteType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNumber(row.getByte(ordinal)) + + case ShortType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNumber(row.getShort(ordinal)) + + case IntegerType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNumber(row.getInt(ordinal)) + + case LongType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNumber(row.getLong(ordinal)) + + case FloatType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNumber(row.getFloat(ordinal)) + + case DoubleType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNumber(row.getDouble(ordinal)) + + case StringType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeString(row.getUTF8String(ordinal).toString) + + case TimestampType => + (row: SpecializedGetters, ordinal: Int) => + val timestampString = timestampFormatter.format(row.getLong(ordinal)) + gen.writeString(timestampString) + + case TimestampNTZType => + (row: SpecializedGetters, ordinal: Int) => + val timestampString = + timestampNTZFormatter.format(DateTimeUtils.microsToLocalDateTime(row.getLong(ordinal))) + gen.writeString(timestampString) + + case DateType => + (row: SpecializedGetters, ordinal: Int) => + val dateString = dateFormatter.format(row.getInt(ordinal)) + gen.writeString(dateString) + + case CalendarIntervalType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeString(row.getInterval(ordinal).toString) + + case YearMonthIntervalType(start, end) => + (row: SpecializedGetters, ordinal: Int) => + val ymString = IntervalUtils.toYearMonthIntervalString( + row.getInt(ordinal), + IntervalStringStyles.ANSI_STYLE, + start, + end) + gen.writeString(ymString) + + case DayTimeIntervalType(start, end) => + (row: SpecializedGetters, ordinal: Int) => + val dtString = IntervalUtils.toDayTimeIntervalString( + row.getLong(ordinal), + IntervalStringStyles.ANSI_STYLE, + start, + end) + gen.writeString(dtString) + + case BinaryType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeBinary(row.getBinary(ordinal)) + 
+ case dt: DecimalType => + (row: SpecializedGetters, ordinal: Int) => + gen.writeNumber(row.getDecimal(ordinal, dt.precision, dt.scale).toJavaBigDecimal) + + case st: StructType => + val fieldWriters = st.map(_.dataType).map(makeWriter) + (row: SpecializedGetters, ordinal: Int) => + writeObject(writeFields(row.getStruct(ordinal, st.length), st, fieldWriters)) + + case at: ArrayType => + val elementWriter = makeWriter(at.elementType) + (row: SpecializedGetters, ordinal: Int) => + writeArray(writeArrayData(row.getArray(ordinal), elementWriter)) + + case mt: MapType => + val valueWriter = makeWriter(mt.valueType) + (row: SpecializedGetters, ordinal: Int) => + writeObject(writeMapData(row.getMap(ordinal), mt, valueWriter)) + + // For UDT values, they should be in the SQL type's corresponding value type. + // We should not see values in the user-defined class at here. + // For example, VectorUDT's SQL type is an array of double. So, we should expect that v is + // an ArrayData at here, instead of a Vector. + case t: UserDefinedType[_] => + makeWriter(t.sqlType) + + case _ => + (row: SpecializedGetters, ordinal: Int) => + val v = row.get(ordinal, dataType) + throw QueryExecutionErrors.failToConvertValueToJsonError(v, v.getClass, dataType) + } + + private def writeObject(f: => Unit): Unit = { + gen.writeStartObject() + f + gen.writeEndObject() + } + + private def writeFields( + row: InternalRow, schema: StructType, fieldWriters: Seq[ValueWriter]): Unit = { + var i = 0 + while (i < row.numFields) { + val field = schema(i) + if (!row.isNullAt(i)) { + gen.writeFieldName(field.name) + fieldWriters(i).apply(row, i) + } else if ((!options.ignoreNullFields || + (options.writeNullIfWithDefaultValue && field.getExistenceDefaultValue().isDefined)) && field.name != "_key") { + gen.writeFieldName(field.name) + gen.writeNull() + } + i += 1 + } + } + + private def writeArray(f: => Unit): Unit = { + gen.writeStartArray() + f + gen.writeEndArray() + } + + private def writeArrayData( + array: ArrayData, fieldWriter: ValueWriter): Unit = { + var i = 0 + while (i < array.numElements()) { + if (!array.isNullAt(i)) { + fieldWriter.apply(array, i) + } else { + gen.writeNull() + } + i += 1 + } + } + + private def writeMapData( + map: MapData, mapType: MapType, fieldWriter: ValueWriter): Unit = { + val keyArray = map.keyArray() + val valueArray = map.valueArray() + var i = 0 + while (i < map.numElements()) { + gen.writeFieldName(keyArray.get(i, mapType.keyType).toString) + if (!valueArray.isNullAt(i)) { + fieldWriter.apply(valueArray, i) + } else { + gen.writeNull() + } + i += 1 + } + } + + def close(): Unit = gen.close() + + def flush(): Unit = gen.flush() + + def writeStartArray(): Unit = { + gen.writeStartArray() + } + + def writeEndArray(): Unit = { + gen.writeEndArray() + } + + /** + * Transforms a single `InternalRow` to JSON object using Jackson. + * This api calling will be validated through accessing `rootFieldWriters`. 
+ * + * @param row The row to convert + */ + def write(row: InternalRow): Unit = { + writeObject(writeFields( + fieldWriters = rootFieldWriters, + row = row, + schema = dataType.asInstanceOf[StructType])) + } + + /** + * Transforms multiple `InternalRow`s or `MapData`s to JSON array using Jackson + * + * @param array The array of rows or maps to convert + */ + def write(array: ArrayData): Unit = writeArray(writeArrayData(array, arrElementWriter)) + + /** + * Transforms a single `MapData` to JSON object using Jackson + * This api calling will will be validated through accessing `mapElementWriter`. + * + * @param map a map to convert + */ + def write(map: MapData): Unit = { + writeObject(writeMapData( + fieldWriter = mapElementWriter, + map = map, + mapType = dataType.asInstanceOf[MapType])) + } + + def writeLineEnding(): Unit = { + // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + gen.writeRaw(lineSeparator) + } +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonParser.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonParser.scala new file mode 100644 index 00000000..cdafc938 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonParser.scala @@ -0,0 +1,589 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off + +package org.apache.spark.sql.arangodb.datasource.mapping.json + +import com.fasterxml.jackson.core._ +import org.apache.spark.SparkUpgradeException +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.util.Utils + +import java.io.{ByteArrayOutputStream, CharConversionException} +import java.nio.charset.MalformedInputException +import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal + +/** + * Constructs a parser for a given schema that translates a json string to an [[InternalRow]]. 
+ */ +class JacksonParser( + schema: DataType, + val options: JSONOptions, + allowArrayAsStructs: Boolean = false, + filters: Seq[Filter] = Seq.empty) extends Logging { + + import JacksonUtils._ + import com.fasterxml.jackson.core.JsonToken._ + + // A `ValueConverter` is responsible for converting a value from `JsonParser` + // to a value in a field for `InternalRow`. + private type ValueConverter = JsonParser => AnyRef + + // `ValueConverter`s for the root schema for all fields in the schema + private val rootConverter = makeRootConverter(schema) + + private val factory = options.buildJsonFactory() + + private lazy val timestampFormatter = TimestampFormatter( + options.timestampFormatInRead, + options.zoneId, + options.locale, + legacyFormat = FAST_DATE_FORMAT, + isParsing = true) + private lazy val timestampNTZFormatter = TimestampFormatter( + options.timestampNTZFormatInRead, + options.zoneId, + legacyFormat = FAST_DATE_FORMAT, + isParsing = true, + forTimestampNTZ = true) + private lazy val dateFormatter = DateFormatter( + options.dateFormatInRead, + options.locale, + legacyFormat = FAST_DATE_FORMAT, + isParsing = true) + + // Flags to signal if we need to fall back to the backward compatible behavior of parsing + // dates and timestamps. + // For more information, see comments for "enableDateTimeParsingFallback" option in JSONOptions. + private val enableParsingFallbackForTimestampType = + options.enableDateTimeParsingFallback + .orElse(SQLConf.get.jsonEnableDateTimeParsingFallback) + .getOrElse { + SQLConf.get.legacyTimeParserPolicy == SQLConf.LegacyBehaviorPolicy.LEGACY || + options.timestampFormatInRead.isEmpty + } + private val enableParsingFallbackForDateType = + options.enableDateTimeParsingFallback + .orElse(SQLConf.get.jsonEnableDateTimeParsingFallback) + .getOrElse { + SQLConf.get.legacyTimeParserPolicy == SQLConf.LegacyBehaviorPolicy.LEGACY || + options.dateFormatInRead.isEmpty + } + + private val enablePartialResults = SQLConf.get.jsonEnablePartialResults + + /** + * Create a converter which converts the JSON documents held by the `JsonParser` + * to a value according to a desired schema. This is a wrapper for the method + * `makeConverter()` to handle a row wrapped with an array. 
+ */ + private def makeRootConverter(dt: DataType): JsonParser => Iterable[InternalRow] = { + dt match { + case st: StructType => makeStructRootConverter(st) + case mt: MapType => makeMapRootConverter(mt) + case at: ArrayType => makeArrayRootConverter(at) + } + } + + private def makeStructRootConverter(st: StructType): JsonParser => Iterable[InternalRow] = { + val elementConverter = makeConverter(st) + val fieldConverters = st.map(_.dataType).map(makeConverter).toArray + val jsonFilters = if (SQLConf.get.jsonFilterPushDown) { + new JsonFilters(filters, st) + } else { + new NoopFilters + } + (parser: JsonParser) => parseJsonToken[Iterable[InternalRow]](parser, st) { + case START_OBJECT => convertObject(parser, st, fieldConverters, jsonFilters, isRoot = true) + // SPARK-3308: support reading top level JSON arrays and take every element + // in such an array as a row + // + // For example, we support, the JSON data as below: + // + // [{"a":"str_a_1"}] + // [{"a":"str_a_2"}, {"b":"str_b_3"}] + // + // resulting in: + // + // List([str_a_1,null]) + // List([str_a_2,null], [null,str_b_3]) + // + case START_ARRAY if allowArrayAsStructs => + val array = convertArray(parser, elementConverter, isRoot = true) + // Here, as we support reading top level JSON arrays and take every element + // in such an array as a row, this case is possible. + if (array.numElements() == 0) { + Array.empty[InternalRow] + } else { + array.toArray[InternalRow](schema) + } + case START_ARRAY => + throw QueryExecutionErrors.cannotParseJsonArraysAsStructsError() + } + } + + private def makeMapRootConverter(mt: MapType): JsonParser => Iterable[InternalRow] = { + val fieldConverter = makeConverter(mt.valueType) + (parser: JsonParser) => parseJsonToken[Iterable[InternalRow]](parser, mt) { + case START_OBJECT => Some(InternalRow(convertMap(parser, fieldConverter))) + } + } + + private def makeArrayRootConverter(at: ArrayType): JsonParser => Iterable[InternalRow] = { + val elemConverter = makeConverter(at.elementType) + (parser: JsonParser) => parseJsonToken[Iterable[InternalRow]](parser, at) { + case START_ARRAY => Some(InternalRow(convertArray(parser, elemConverter))) + case START_OBJECT if at.elementType.isInstanceOf[StructType] => + // This handles the case when an input JSON object is a structure but + // the specified schema is an array of structures. In that case, the input JSON is + // considered as an array of only one element of struct type. + // This behavior was introduced by changes for SPARK-19595. + // + // For example, if the specified schema is ArrayType(new StructType().add("i", IntegerType)) + // and JSON input as below: + // + // [{"i": 1}, {"i": 2}] + // [{"i": 3}] + // {"i": 4} + // + // The last row is considered as an array with one element, and result of conversion: + // + // Seq(Row(1), Row(2)) + // Seq(Row(3)) + // Seq(Row(4)) + // + val st = at.elementType.asInstanceOf[StructType] + val fieldConverters = st.map(_.dataType).map(makeConverter).toArray + Some(InternalRow(new GenericArrayData(convertObject(parser, st, fieldConverters).toArray))) + } + } + + private val decimalParser = ExprUtils.getDecimalParser(options.locale) + + /** + * Create a converter which converts the JSON documents held by the `JsonParser` + * to a value according to a desired schema. 
+ */ + def makeConverter(dataType: DataType): ValueConverter = dataType match { + case BooleanType => + (parser: JsonParser) => parseJsonToken[java.lang.Boolean](parser, dataType) { + case VALUE_TRUE => true + case VALUE_FALSE => false + } + + case ByteType => + (parser: JsonParser) => parseJsonToken[java.lang.Byte](parser, dataType) { + case VALUE_NUMBER_INT => parser.getByteValue + } + + case ShortType => + (parser: JsonParser) => parseJsonToken[java.lang.Short](parser, dataType) { + case VALUE_NUMBER_INT => parser.getShortValue + } + + case IntegerType => + (parser: JsonParser) => parseJsonToken[java.lang.Integer](parser, dataType) { + case VALUE_NUMBER_INT => parser.getIntValue + } + + case LongType => + (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) { + case VALUE_NUMBER_INT => parser.getLongValue + } + + case FloatType => + (parser: JsonParser) => parseJsonToken[java.lang.Float](parser, dataType) { + case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT => + parser.getFloatValue + + case VALUE_STRING if parser.getTextLength >= 1 => + // Special case handling for NaN and Infinity. + parser.getText match { + case "NaN" if options.allowNonNumericNumbers => + Float.NaN + case "+INF" | "+Infinity" | "Infinity" if options.allowNonNumericNumbers => + Float.PositiveInfinity + case "-INF" | "-Infinity" if options.allowNonNumericNumbers => + Float.NegativeInfinity + case _ => throw QueryExecutionErrors.cannotParseStringAsDataTypeError( + parser, VALUE_STRING, FloatType) + } + } + + case DoubleType => + (parser: JsonParser) => parseJsonToken[java.lang.Double](parser, dataType) { + case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT => + parser.getDoubleValue + + case VALUE_STRING if parser.getTextLength >= 1 => + // Special case handling for NaN and Infinity. + parser.getText match { + case "NaN" if options.allowNonNumericNumbers => + Double.NaN + case "+INF" | "+Infinity" | "Infinity" if options.allowNonNumericNumbers => + Double.PositiveInfinity + case "-INF" | "-Infinity" if options.allowNonNumericNumbers => + Double.NegativeInfinity + case _ => throw QueryExecutionErrors.cannotParseStringAsDataTypeError( + parser, VALUE_STRING, DoubleType) + } + } + + case StringType => + (parser: JsonParser) => parseJsonToken[UTF8String](parser, dataType) { + case VALUE_STRING => + UTF8String.fromString(parser.getText) + + case _ => + // Note that it always tries to convert the data as string without the case of failure. + val writer = new ByteArrayOutputStream() + Utils.tryWithResource(factory.createGenerator(writer, JsonEncoding.UTF8)) { + generator => generator.copyCurrentStructure(parser) + } + UTF8String.fromBytes(writer.toByteArray) + } + + case TimestampType => + (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) { + case VALUE_STRING if parser.getTextLength >= 1 => + try { + timestampFormatter.parse(parser.getText) + } catch { + case NonFatal(e) => + // If fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility if enabled. 
+ if (!enableParsingFallbackForTimestampType) { + throw e + } + val str = DateTimeUtils.cleanLegacyTimestampStr(UTF8String.fromString(parser.getText)) + DateTimeUtils.stringToTimestamp(str, options.zoneId).getOrElse(throw e) + } + + case VALUE_NUMBER_INT => + parser.getLongValue * 1000L + } + + case TimestampNTZType => + (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) { + case VALUE_STRING if parser.getTextLength >= 1 => + timestampNTZFormatter.parseWithoutTimeZone(parser.getText, false) + } + + case DateType => + (parser: JsonParser) => parseJsonToken[java.lang.Integer](parser, dataType) { + case VALUE_STRING if parser.getTextLength >= 1 => + try { + dateFormatter.parse(parser.getText) + } catch { + case NonFatal(e) => + // If fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility if enabled. + if (!enableParsingFallbackForDateType) { + throw e + } + val str = DateTimeUtils.cleanLegacyTimestampStr(UTF8String.fromString(parser.getText)) + DateTimeUtils.stringToDate(str).getOrElse { + // In Spark 1.5.0, we store the data as number of days since epoch in string. + // So, we just convert it to Int. + try { + RebaseDateTime.rebaseJulianToGregorianDays(parser.getText.toInt) + } catch { + case _: NumberFormatException => throw e + } + }.asInstanceOf[Integer] + } + } + + case BinaryType => + (parser: JsonParser) => parseJsonToken[Array[Byte]](parser, dataType) { + case VALUE_STRING => parser.getBinaryValue + } + + case dt: DecimalType => + (parser: JsonParser) => parseJsonToken[Decimal](parser, dataType) { + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) => + Decimal(parser.getDecimalValue, dt.precision, dt.scale) + case VALUE_STRING if parser.getTextLength >= 1 => + val bigDecimal = decimalParser(parser.getText) + Decimal(bigDecimal, dt.precision, dt.scale) + } + + case CalendarIntervalType => (parser: JsonParser) => + parseJsonToken[CalendarInterval](parser, dataType) { + case VALUE_STRING => + IntervalUtils.safeStringToInterval(UTF8String.fromString(parser.getText)) + } + + case ym: YearMonthIntervalType => (parser: JsonParser) => + parseJsonToken[Integer](parser, dataType) { + case VALUE_STRING => + val expr = Cast(Literal(parser.getText), ym) + Integer.valueOf(expr.eval(EmptyRow).asInstanceOf[Int]) + } + + case dt: DayTimeIntervalType => (parser: JsonParser) => + parseJsonToken[java.lang.Long](parser, dataType) { + case VALUE_STRING => + val expr = Cast(Literal(parser.getText), dt) + java.lang.Long.valueOf(expr.eval(EmptyRow).asInstanceOf[Long]) + } + + case st: StructType => + val fieldConverters = st.map(_.dataType).map(makeConverter).toArray + (parser: JsonParser) => parseJsonToken[InternalRow](parser, dataType) { + case START_OBJECT => convertObject(parser, st, fieldConverters).get + } + + case at: ArrayType => + val elementConverter = makeConverter(at.elementType) + (parser: JsonParser) => parseJsonToken[ArrayData](parser, dataType) { + case START_ARRAY => convertArray(parser, elementConverter) + } + + case mt: MapType => + val valueConverter = makeConverter(mt.valueType) + (parser: JsonParser) => parseJsonToken[MapData](parser, dataType) { + case START_OBJECT => convertMap(parser, valueConverter) + } + + case udt: UserDefinedType[_] => + makeConverter(udt.sqlType) + + case _: NullType => + (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) { + case _ => null + } + + // We don't actually hit this exception though, we keep it for understandability + case _ => throw 
QueryExecutionErrors.unsupportedTypeError(dataType) + } + + /** + * This method skips `FIELD_NAME`s at the beginning, and handles nulls ahead before trying + * to parse the JSON token using given function `f`. If the `f` failed to parse and convert the + * token, call `failedConversion` to handle the token. + */ + @scala.annotation.tailrec + private def parseJsonToken[R >: Null]( + parser: JsonParser, + dataType: DataType)(f: PartialFunction[JsonToken, R]): R = { + parser.getCurrentToken match { + case FIELD_NAME => + // There are useless FIELD_NAMEs between START_OBJECT and END_OBJECT tokens + parser.nextToken() + parseJsonToken[R](parser, dataType)(f) + + case null | VALUE_NULL => null + + case other => f.applyOrElse(other, failedConversion(parser, dataType)) + } + } + + private val allowEmptyString = SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_EMPTY_STRING_IN_JSON) + + /** + * This function throws an exception for failed conversion. For empty string on data types + * except for string and binary types, this also throws an exception. + */ + private def failedConversion[R >: Null]( + parser: JsonParser, + dataType: DataType): PartialFunction[JsonToken, R] = { + + // SPARK-25040: Disallows empty strings for data types except for string and binary types. + // But treats empty strings as null for certain types if the legacy config is enabled. + case VALUE_STRING if parser.getTextLength < 1 && allowEmptyString => + dataType match { + case FloatType | DoubleType | TimestampType | DateType => + throw QueryExecutionErrors.emptyJsonFieldValueError(dataType) + case _ => null + } + + case VALUE_STRING if parser.getTextLength < 1 => + throw QueryExecutionErrors.emptyJsonFieldValueError(dataType) + + case token => + // We cannot parse this token based on the given data type. So, we throw a + // RuntimeException and this exception will be caught by `parse` method. + throw QueryExecutionErrors.cannotParseJSONFieldError(parser, token, dataType) + } + + /** + * Parse an object from the token stream into a new Row representing the schema. + * Fields in the json that are not defined in the requested schema will be dropped. + */ + private def convertObject( + parser: JsonParser, + schema: StructType, + fieldConverters: Array[ValueConverter], + structFilters: StructFilters = new NoopFilters(), + isRoot: Boolean = false): Option[InternalRow] = { + val row = new GenericInternalRow(schema.length) + var badRecordException: Option[Throwable] = None + var skipRow = false + + structFilters.reset() + resetExistenceDefaultsBitmask(schema) + while (!skipRow && nextUntil(parser, JsonToken.END_OBJECT)) { + schema.getFieldIndex(parser.getCurrentName) match { + case Some(index) => + try { + row.update(index, fieldConverters(index).apply(parser)) + skipRow = structFilters.skipRow(row, index) + schema.existenceDefaultsBitmask(index) = false + } catch { + case e: SparkUpgradeException => throw e + case NonFatal(e) if isRoot || enablePartialResults => + badRecordException = badRecordException.orElse(Some(e)) + parser.skipChildren() + } + case None => + parser.skipChildren() + } + } + if (skipRow) { + None + } else if (badRecordException.isEmpty) { + applyExistenceDefaultValuesToRow(schema, row) + Some(row) + } else { + throw PartialResultException(row, badRecordException.get) + } + } + + /** + * Parse an object as a Map, preserving all fields. 
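+ * For example, with schema MapType(StringType, LongType) the object {"a": 1, "b": 2} is + * converted to a MapData with keys ["a", "b"] and values [1, 2].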
+ */ + private def convertMap( + parser: JsonParser, + fieldConverter: ValueConverter): MapData = { + val keys = ArrayBuffer.empty[UTF8String] + val values = ArrayBuffer.empty[Any] + var badRecordException: Option[Throwable] = None + + while (nextUntil(parser, JsonToken.END_OBJECT)) { + keys += UTF8String.fromString(parser.getCurrentName) + try { + values += fieldConverter.apply(parser) + } catch { + case PartialResultException(row, cause) if enablePartialResults => + badRecordException = badRecordException.orElse(Some(cause)) + values += row + case NonFatal(e) if enablePartialResults => + badRecordException = badRecordException.orElse(Some(e)) + parser.skipChildren() + } + } + + // The JSON map will never have null or duplicated map keys, it's safe to create a + // ArrayBasedMapData directly here. + val mapData = ArrayBasedMapData(keys.toArray, values.toArray) + + if (badRecordException.isEmpty) { + mapData + } else { + throw PartialResultException(InternalRow(mapData), badRecordException.get) + } + } + + /** + * Parse an object as a Array. + */ + private def convertArray( + parser: JsonParser, + fieldConverter: ValueConverter, + isRoot: Boolean = false): ArrayData = { + val values = ArrayBuffer.empty[Any] + var badRecordException: Option[Throwable] = None + + while (nextUntil(parser, JsonToken.END_ARRAY)) { + try { + val v = fieldConverter.apply(parser) + if (isRoot && v == null) throw QueryExecutionErrors.rootConverterReturnNullError() + values += v + } catch { + case PartialResultException(row, cause) if enablePartialResults => + badRecordException = badRecordException.orElse(Some(cause)) + values += row + } + } + + val arrayData = new GenericArrayData(values.toArray) + + if (badRecordException.isEmpty) { + arrayData + } else { + throw PartialResultException(InternalRow(arrayData), badRecordException.get) + } + } + + /** + * Parse the JSON input to the set of [[InternalRow]]s. + * + * @param recordLiteral an optional function that will be used to generate + * the corrupt record text instead of record.toString + */ + def parse[T]( + record: T, + createParser: (JsonFactory, T) => JsonParser, + recordLiteral: T => UTF8String): Iterable[InternalRow] = { + try { + Utils.tryWithResource(createParser(factory, record)) { parser => + // a null first token is equivalent to testing for input.trim.isEmpty + // but it works on any token stream and not just strings + parser.nextToken() match { + case null => None + case _ => rootConverter.apply(parser) match { + case null => throw QueryExecutionErrors.rootConverterReturnNullError() + case rows => rows.toSeq + } + } + } + } catch { + case e: SparkUpgradeException => throw e + case e @ (_: RuntimeException | _: JsonProcessingException | _: MalformedInputException) => + // JSON parser currently doesn't support partial results for corrupted records. + // For such records, all fields other than the field configured by + // `columnNameOfCorruptRecord` are set to `null`. + throw BadRecordException(() => recordLiteral(record), () => None, e) + case e: CharConversionException if options.encoding.isEmpty => + val msg = + """JSON parser cannot handle a character in its input. + |Specifying encoding as an input option explicitly might help to resolve the issue. 
+ |""".stripMargin + e.getMessage + val wrappedCharException = new CharConversionException(msg) + wrappedCharException.initCause(e) + throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException) + case PartialResultException(row, cause) => + throw BadRecordException( + record = () => recordLiteral(record), + partialResult = () => Some(row), + cause) + } + } +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonUtils.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonUtils.scala new file mode 100644 index 00000000..122800b0 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JacksonUtils.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off + +package org.apache.spark.sql.arangodb.datasource.mapping.json + +import com.fasterxml.jackson.core.{JsonParser, JsonToken} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} +import org.apache.spark.sql.errors.QueryErrorsBase +import org.apache.spark.sql.types._ + +object JacksonUtils extends QueryErrorsBase { + /** + * Advance the parser until a null or a specific token is found + */ + def nextUntil(parser: JsonParser, stopOn: JsonToken): Boolean = { + parser.nextToken() match { + case null => false + case x => x != stopOn + } + } + + def verifyType(name: String, dataType: DataType): TypeCheckResult = { + dataType match { + case NullType | _: AtomicType | CalendarIntervalType => TypeCheckSuccess + + case st: StructType => + st.foldLeft(TypeCheckSuccess: TypeCheckResult) { case (currResult, field) => + if (currResult.isFailure) currResult else verifyType(field.name, field.dataType) + } + + case at: ArrayType => verifyType(name, at.elementType) + + // For MapType, its keys are treated as a string (i.e. calling `toString`) basically when + // generating JSON, so we only care if the values are valid for JSON. 
+ case mt: MapType => verifyType(name, mt.valueType) + + case udt: UserDefinedType[_] => verifyType(name, udt.sqlType) + + case _ => + DataTypeMismatch( + errorSubClass = "CANNOT_CONVERT_TO_JSON", + messageParameters = Map( + "name" -> toSQLId(name), + "type" -> toSQLType(dataType))) + } + } +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JsonFilters.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JsonFilters.scala new file mode 100644 index 00000000..1665787a --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JsonFilters.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off + +package org.apache.spark.sql.arangodb.datasource.mapping.json + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.{InternalRow, StructFilters} +import org.apache.spark.sql.sources +import org.apache.spark.sql.types.StructType + +/** + * The class provides API for applying pushed down source filters to rows with + * a struct schema parsed from JSON records. The class should be used in this way: + * 1. Before processing of the next row, `JacksonParser` (parser for short) resets the internal + * state of `JsonFilters` by calling the `reset()` method. + * 2. The parser reads JSON fields one-by-one in streaming fashion. It converts an incoming + * field value to the desired type from the schema. After that, it sets the value to an instance + * of `InternalRow` at the position according to the schema. Order of parsed JSON fields can + * be different from the order in the schema. + * 3. Per every JSON field of the top-level JSON object, the parser calls `skipRow` by passing + * an `InternalRow` in which some of fields can be already set, and the position of the JSON + * field according to the schema. + * 3.1 `skipRow` finds a group of predicates that refers to this JSON field. + * 3.2 Per each predicate from the group, `skipRow` decrements its reference counter. + * 3.2.1 If predicate reference counter becomes 0, it means that all predicate attributes have + * been already set in the internal row, and the predicate can be applied to it. `skipRow` + * invokes the predicate for the row. + * 3.3 `skipRow` applies predicates until one of them returns `false`. In that case, the method + * returns `true` to the parser. + * 3.4 If all predicates with zero reference counter return `true`, the final result of + * the method is `false` which tells the parser to not skip the row. + * 4. 
If the parser gets `true` from `JsonFilters.skipRow`, it must not call the method anymore + * for this internal row, and should go back to step 1. + * + * Besides the `StructFilters` assumptions, `JsonFilters` assumes that: + * - `skipRow()` can be called for any valid index of the struct fields, + * and in any order. + * - After `skipRow()` returns `true`, the internal state of `JsonFilters` can be inconsistent, + * so `skipRow()` must not be called for the current row anymore without `reset()`. + * + * @param pushedFilters The pushed down source filters. The filters should refer to + * the fields of the provided schema. + * @param schema The required schema of records from datasource files. + */ +class JsonFilters(pushedFilters: Seq[sources.Filter], schema: StructType) + extends StructFilters(pushedFilters, schema) { + + /** + * Stateful JSON predicate that keeps track of its dependent references in the + * current row via `refCount`. + * + * @param predicate The predicate compiled from pushed down source filters. + * @param totalRefs The total number of filter references from which the predicate + * was compiled. + */ + case class JsonPredicate(predicate: BasePredicate, totalRefs: Int) { + // The current number of predicate references in the row that have not been set yet. + // When `refCount` reaches zero, all of the predicate's dependencies are set, and it can + // be applied to the row. + var refCount: Int = totalRefs + + def reset(): Unit = { + refCount = totalRefs + } + } + + // Predicates compiled from the pushed down filters. The predicates are grouped by their + // attributes. The i-th group contains predicates that refer to the i-th field of the given + // schema. A predicate can be placed in many groups if it has many attributes. For example: + // schema: i INTEGER, s STRING + // filters: IsNotNull("i"), AlwaysTrue, Or(EqualTo("i", 0), StringStartsWith("s", "abc")) + // predicates: + // 0: Array(IsNotNull("i"), AlwaysTrue, Or(EqualTo("i", 0), StringStartsWith("s", "abc"))) + // 1: Array(AlwaysTrue, Or(EqualTo("i", 0), StringStartsWith("s", "abc"))) + private val predicates: Array[Array[JsonPredicate]] = { + val groupedPredicates = Array.fill(schema.length)(Array.empty[JsonPredicate]) + val groupedByRefSet: Map[Set[String], JsonPredicate] = filters + // Group filters that have the same set of references. For example: + // IsNotNull("i") -> Set("i"), AlwaysTrue -> Set(), + // Or(EqualTo("i", 0), StringStartsWith("s", "abc")) -> Set("i", "s") + // By grouping filters we can avoid tracking the state of their references in the + // current row separately. + .groupBy(_.references.toSet) + // Combine all filters from the same group by `And` because all filters should + // return `true` in order not to skip a row. The result is compiled to a predicate. + .map { case (refSet, refsFilters) => + (refSet, JsonPredicate(toPredicate(refsFilters), refSet.size)) + } + // Apply predicates w/o references like `AlwaysTrue` and `AlwaysFalse` to all fields. + // We cannot set such predicates at a particular position because skipRow() can + // be invoked for any index due to unpredictable order of JSON fields in JSON records.
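+ // In the example above, AlwaysTrue has an empty reference set: it is re-keyed below to + // Set("i", "s") with totalRefs = 1, so it is evaluated as soon as any single field of the row is set.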
+ val withLiterals: Map[Set[String], JsonPredicate] = groupedByRefSet.map { + case (refSet, pred) if refSet.isEmpty => + (schema.fields.map(_.name).toSet, pred.copy(totalRefs = 1)) + case others => others + } + // Build a map where key is only one field and value is seq of predicates refer to the field + // "i" -> Seq(AlwaysTrue, IsNotNull("i"), Or(EqualTo("i", 0), StringStartsWith("s", "abc"))) + // "s" -> Seq(AlwaysTrue, Or(EqualTo("i", 0), StringStartsWith("s", "abc"))) + val groupedByFields: Map[String, Seq[(String, JsonPredicate)]] = withLiterals.toSeq + .flatMap { case (refSet, pred) => refSet.map((_, pred)) } + .groupBy(_._1) + // Build the final array by converting keys of `groupedByFields` to their + // indexes in the provided schema. + groupedByFields.foreach { case (fieldName, fieldPredicates) => + val fieldIndex = schema.fieldIndex(fieldName) + groupedPredicates(fieldIndex) = fieldPredicates.map(_._2).toArray + } + groupedPredicates + } + + /** + * Applies predicates (compiled filters) associated with the row field value + * at the position `index` only if other predicates dependencies are already + * set in the given row. + * + * Note: If the function returns `true`, `refCount` of some predicates can be not decremented. + * + * @param row The row with fully or partially set values. + * @param index The index of already set value. + * @return `true` if at least one of applicable predicates (all dependent row values are set) + * return `false`. It returns `false` if all predicates return `true`. + */ + def skipRow(row: InternalRow, index: Int): Boolean = { + assert(0 <= index && index < schema.fields.length, + s"The index $index is out of the valid range [0, ${schema.fields.length}). " + + s"It must point out to a field of the schema: ${schema.catalogString}.") + var skip = false + for (pred <- predicates(index) if !skip) { + pred.refCount -= 1 + assert(pred.refCount >= 0, + s"Predicate reference counter cannot be negative but got ${pred.refCount}.") + skip = pred.refCount == 0 && !pred.predicate.eval(row) + } + skip + } + + /** + * Reset states of all predicates by re-initializing reference counters. + */ + override def reset(): Unit = predicates.foreach(_.foreach(_.reset)) +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JsonInferSchema.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JsonInferSchema.scala new file mode 100644 index 00000000..d2c70f71 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/json/JsonInferSchema.scala @@ -0,0 +1,413 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off + +package org.apache.spark.sql.arangodb.datasource.mapping.json + +import com.fasterxml.jackson.core._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis.TypeCoercion +import org.apache.spark.sql.catalyst.expressions.ExprUtils +import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +import java.io.CharConversionException +import java.nio.charset.MalformedInputException +import java.util.Comparator +import scala.util.control.Exception.allCatch + +private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { + + private val decimalParser = ExprUtils.getDecimalParser(options.locale) + + private val timestampFormatter = TimestampFormatter( + options.timestampFormatInRead, + options.zoneId, + options.locale, + legacyFormat = FAST_DATE_FORMAT, + isParsing = true) + private val timestampNTZFormatter = TimestampFormatter( + options.timestampNTZFormatInRead, + options.zoneId, + legacyFormat = FAST_DATE_FORMAT, + isParsing = true, + forTimestampNTZ = true) + + private def handleJsonErrorsByParseMode(parseMode: ParseMode, + columnNameOfCorruptRecord: String, e: Throwable): Option[StructType] = { + parseMode match { + case PermissiveMode => + Some(StructType(Array(StructField(columnNameOfCorruptRecord, StringType)))) + case DropMalformedMode => + None + case FailFastMode => + throw QueryExecutionErrors.malformedRecordsDetectedInSchemaInferenceError(e) + } + } + + /** + * Infer the type of a collection of json records in three stages: + * 1. Infer the type of each record + * 2. Merge types by choosing the lowest type necessary to cover equal keys + * 3. Replace any remaining null fields with string, the top type + */ + def infer[T]( + json: RDD[T], + createParser: (JsonFactory, T) => JsonParser): StructType = { + val parseMode = options.parseMode + val columnNameOfCorruptRecord = options.columnNameOfCorruptRecord + + // In each RDD partition, perform schema inference on each row and merge afterwards. + val typeMerger = JsonInferSchema.compatibleRootType(columnNameOfCorruptRecord, parseMode) + val mergedTypesFromPartitions = json.mapPartitions { iter => + val factory = options.buildJsonFactory() + iter.flatMap { row => + try { + Utils.tryWithResource(createParser(factory, row)) { parser => + parser.nextToken() + Some(inferField(parser)) + } + } catch { + case e @ (_: RuntimeException | _: JsonProcessingException | + _: MalformedInputException) => + handleJsonErrorsByParseMode(parseMode, columnNameOfCorruptRecord, e) + case e: CharConversionException if options.encoding.isEmpty => + val msg = + """JSON parser cannot handle a character in its input. + |Specifying encoding as an input option explicitly might help to resolve the issue. + |""".stripMargin + e.getMessage + val wrappedCharException = new CharConversionException(msg) + wrappedCharException.initCause(e) + handleJsonErrorsByParseMode(parseMode, columnNameOfCorruptRecord, wrappedCharException) + } + }.reduceOption(typeMerger).iterator + } + + // Here we manually submit a fold-like Spark job, so that we can set the SQLConf when running + // the fold functions in the scheduler event loop thread. 
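+ // mergeResult below runs on the DAGScheduler event-loop thread, which does not carry the + // session's SQLConf, so SQLConf.withExistingConf makes the captured conf visible to typeMerger there.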
+ val existingConf = SQLConf.get + var rootType: DataType = StructType(Nil) + val foldPartition = (iter: Iterator[DataType]) => iter.fold(StructType(Nil))(typeMerger) + val mergeResult = (index: Int, taskResult: DataType) => { + rootType = SQLConf.withExistingConf(existingConf) { + typeMerger(rootType, taskResult) + } + } + json.sparkContext.runJob(mergedTypesFromPartitions, foldPartition, mergeResult) + + canonicalizeType(rootType, options) + .find(_.isInstanceOf[StructType]) + // canonicalizeType erases all empty structs, including the only one we want to keep + .getOrElse(StructType(Nil)).asInstanceOf[StructType] + } + + /** + * Infer the type of a json document from the parser's token stream + */ + def inferField(parser: JsonParser): DataType = { + import com.fasterxml.jackson.core.JsonToken._ + parser.getCurrentToken match { + case null | VALUE_NULL => NullType + + case FIELD_NAME => + parser.nextToken() + inferField(parser) + + case VALUE_STRING if parser.getTextLength < 1 => + // Zero length strings and nulls have special handling to deal + // with JSON generators that do not distinguish between the two. + // To accurately infer types for empty strings that are really + // meant to represent nulls we assume that the two are isomorphic + // but will defer treating null fields as strings until all the + // record fields' types have been combined. + NullType + + case VALUE_STRING => + val field = parser.getText + lazy val decimalTry = allCatch opt { + val bigDecimal = decimalParser(field) + DecimalType(bigDecimal.precision, bigDecimal.scale) + } + if (options.prefersDecimal && decimalTry.isDefined) { + decimalTry.get + } else if (options.inferTimestamp && + timestampNTZFormatter.parseWithoutTimeZoneOptional(field, false).isDefined) { + SQLConf.get.timestampType + } else if (options.inferTimestamp && + timestampFormatter.parseOptional(field).isDefined) { + TimestampType + } else { + StringType + } + + case START_OBJECT => + val builder = Array.newBuilder[StructField] + while (nextUntil(parser, END_OBJECT)) { + builder += StructField( + parser.getCurrentName, + inferField(parser), + nullable = true) + } + val fields: Array[StructField] = builder.result() + // Note: other code relies on this sorting for correctness, so don't remove it! + java.util.Arrays.sort(fields, JsonInferSchema.structFieldComparator) + StructType(fields) + + case START_ARRAY => + // If this JSON array is empty, we use NullType as a placeholder. + // If this array is not empty in other JSON objects, we can resolve + // the type as we pass through all JSON objects. + var elementType: DataType = NullType + while (nextUntil(parser, END_ARRAY)) { + elementType = JsonInferSchema.compatibleType( + elementType, inferField(parser)) + } + + ArrayType(elementType) + + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if options.primitivesAsString => StringType + + case (VALUE_TRUE | VALUE_FALSE) if options.primitivesAsString => StringType + + case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT => + import JsonParser.NumberType._ + parser.getNumberType match { + // For Integer values, use LongType by default. + case INT | LONG => LongType + // Since we do not have a data type backed by BigInteger, + // when we see a Java BigInteger, we use DecimalType. 
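+ // e.g. 92233720368547758070 (20 digits) is inferred as DecimalType(20, 0), while values + // wider than DecimalType.MAX_PRECISION (38 digits) fall back to DoubleType.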
+ case BIG_INTEGER | BIG_DECIMAL => + val v = parser.getDecimalValue + if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) { + DecimalType(Math.max(v.precision(), v.scale()), v.scale()) + } else { + DoubleType + } + case FLOAT | DOUBLE if options.prefersDecimal => + val v = parser.getDecimalValue + if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) { + DecimalType(Math.max(v.precision(), v.scale()), v.scale()) + } else { + DoubleType + } + case FLOAT | DOUBLE => + DoubleType + } + + case VALUE_TRUE | VALUE_FALSE => BooleanType + + case _ => + throw QueryExecutionErrors.malformedJSONError() + } + } + + /** + * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields, + * drops NullTypes or converts them to StringType based on provided options. + */ + private[json] def canonicalizeType( + tpe: DataType, options: JSONOptions): Option[DataType] = tpe match { + case at: ArrayType => + canonicalizeType(at.elementType, options) + .map(t => at.copy(elementType = t)) + + case StructType(fields) => + val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f => + canonicalizeType(f.dataType, options) + .map(t => f.copy(dataType = t)) + } + // SPARK-8093: empty structs should be deleted + if (canonicalFields.isEmpty) { + None + } else { + Some(StructType(canonicalFields)) + } + + case NullType => + if (options.dropFieldIfAllNull) { + None + } else { + Some(StringType) + } + + case other => Some(other) + } +} + +object JsonInferSchema { + val structFieldComparator = new Comparator[StructField] { + override def compare(o1: StructField, o2: StructField): Int = { + o1.name.compareTo(o2.name) + } + } + + def isSorted(arr: Array[StructField]): Boolean = { + var i: Int = 0 + while (i < arr.length - 1) { + if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) { + return false + } + i += 1 + } + true + } + + def withCorruptField( + struct: StructType, + other: DataType, + columnNameOfCorruptRecords: String, + parseMode: ParseMode): StructType = parseMode match { + case PermissiveMode => + // If we see any other data type at the root level, we get records that cannot be + // parsed. So, we use the struct as the data type and add the corrupt field to the schema. + if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { + // If this given struct does not have a column used for corrupt records, + // add this field. + val newFields: Array[StructField] = + StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields + // Note: other code relies on this sorting for correctness, so don't remove it! + java.util.Arrays.sort(newFields, structFieldComparator) + StructType(newFields) + } else { + // Otherwise, just return this struct. + struct + } + + case DropMalformedMode => + // If corrupt record handling is disabled we retain the valid schema and discard the other. + struct + + case FailFastMode => + // If `other` is not struct type, consider it as malformed one and throws an exception. + throw QueryExecutionErrors.malformedRecordsDetectedInSchemaInferenceError(other) + } + + /** + * Remove top-level ArrayType wrappers and merge the remaining schemas + */ + def compatibleRootType( + columnNameOfCorruptRecords: String, + parseMode: ParseMode): (DataType, DataType) => DataType = { + // Since we support array of json objects at the top level, + // we need to check the element type and find the root level data type. 
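+ // e.g. merging ArrayType(StructType(...)) with StructType(...) unwraps the array first and + // then merges the two struct types via compatibleType.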
+ case (ArrayType(ty1, _), ty2) => + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) + case (ty1, ArrayType(ty2, _)) => + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) + // Discard null/empty documents + case (struct: StructType, NullType) => struct + case (NullType, struct: StructType) => struct + case (struct: StructType, o) if !o.isInstanceOf[StructType] => + withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) + case (o, struct: StructType) if !o.isInstanceOf[StructType] => + withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) + // If we get anything else, we call compatibleType. + // Usually, when we reach here, ty1 and ty2 are two StructTypes. + case (ty1, ty2) => compatibleType(ty1, ty2) + } + + private[this] val emptyStructFieldArray = Array.empty[StructField] + + /** + * Returns the most general data type for two given data types. + */ + def compatibleType(t1: DataType, t2: DataType): DataType = { + TypeCoercion.findTightestCommonType(t1, t2).getOrElse { + // t1 or t2 is a StructType, ArrayType, or an unexpected type. + (t1, t2) match { + // Double support larger range than fixed decimal, DecimalType.Maximum should be enough + // in most case, also have better precision. + case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => + DoubleType + + case (t1: DecimalType, t2: DecimalType) => + val scale = math.max(t1.scale, t2.scale) + val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) + if (range + scale > 38) { + // DecimalType can't support precision > 38 + DoubleType + } else { + DecimalType(range + scale, scale) + } + + case (StructType(fields1), StructType(fields2)) => + // Both fields1 and fields2 should be sorted by name, since inferField performs sorting. + // Therefore, we can take advantage of the fact that we're merging sorted lists and skip + // building a hash map or performing additional sorting. + assert(isSorted(fields1), + s"${StructType.simpleString}'s fields were not sorted: ${fields1.toSeq}") + assert(isSorted(fields2), + s"${StructType.simpleString}'s fields were not sorted: ${fields2.toSeq}") + + val newFields = new java.util.ArrayList[StructField]() + + var f1Idx = 0 + var f2Idx = 0 + + while (f1Idx < fields1.length && f2Idx < fields2.length) { + val f1Name = fields1(f1Idx).name + val f2Name = fields2(f2Idx).name + val comp = f1Name.compareTo(f2Name) + if (comp == 0) { + val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType) + newFields.add(StructField(f1Name, dataType, nullable = true)) + f1Idx += 1 + f2Idx += 1 + } else if (comp < 0) { // f1Name < f2Name + newFields.add(fields1(f1Idx)) + f1Idx += 1 + } else { // f1Name > f2Name + newFields.add(fields2(f2Idx)) + f2Idx += 1 + } + } + while (f1Idx < fields1.length) { + newFields.add(fields1(f1Idx)) + f1Idx += 1 + } + while (f2Idx < fields2.length) { + newFields.add(fields2(f2Idx)) + f2Idx += 1 + } + StructType(newFields.toArray(emptyStructFieldArray)) + + case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => + ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) + + // The case that given `DecimalType` is capable of given `IntegralType` is handled in + // `findTightestCommonType`. Both cases below will be executed only when the given + // `DecimalType` is not capable of the given `IntegralType`. 
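+ // e.g. LongType vs DecimalType(5, 2): DecimalType.forType(LongType) is DecimalType(20, 0) and + // the decimal-decimal merge above yields DecimalType(22, 2) (range 20, scale 2).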
+ case (t1: IntegralType, t2: DecimalType) => + compatibleType(DecimalType.forType(t1), t2) + case (t1: DecimalType, t2: IntegralType) => + compatibleType(t1, DecimalType.forType(t2)) + + case (TimestampNTZType, TimestampType) | (TimestampType, TimestampNTZType) => + TimestampType + + // strings and every string is a Json object. + case (_, _) => StringType + } + } + } +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/package.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/package.scala new file mode 100644 index 00000000..90270213 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/mapping/package.scala @@ -0,0 +1,14 @@ +package org.apache.spark.sql.arangodb.datasource + +import com.fasterxml.jackson.core.JsonFactory +import org.apache.spark.sql.arangodb.commons.ArangoDBConf +import org.apache.spark.sql.arangodb.datasource.mapping.json.JSONOptions + +package object mapping { + private[mapping] def createOptions(jsonFactory: JsonFactory, conf: ArangoDBConf) = + new JSONOptions(Map.empty[String, String], "UTC") { + override def buildJsonFactory(): JsonFactory = jsonFactory + + override val ignoreNullFields: Boolean = conf.mappingOptions.ignoreNullFields + } +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartition.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartition.scala new file mode 100644 index 00000000..ce354e0e --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartition.scala @@ -0,0 +1,15 @@ +package org.apache.spark.sql.arangodb.datasource.reader + +import org.apache.spark.sql.connector.read.InputPartition + +/** + * Partition corresponding to an Arango collection shard + * @param shardId collection shard id + * @param endpoint db endpoint to use to query the partition + */ +class ArangoCollectionPartition(val shardId: String, val endpoint: String) extends InputPartition + +/** + * Custom user queries will not be partitioned (eg. 
AQL traversals) + */ +object SingletonPartition extends InputPartition diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala new file mode 100644 index 00000000..d24ec187 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoCollectionPartitionReader.scala @@ -0,0 +1,65 @@ +package org.apache.spark.sql.arangodb.datasource.reader + +import com.arangodb.entity.CursorWarning +import org.apache.spark.internal.Logging +import org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider +import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx +import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.FailureSafeParser +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.types.StructType + +import scala.annotation.tailrec +import scala.collection.JavaConverters.iterableAsScalaIterableConverter + + +class ArangoCollectionPartitionReader(inputPartition: ArangoCollectionPartition, ctx: PushDownCtx, opts: ArangoDBConf) + extends PartitionReader[InternalRow] with Logging { + + // override endpoints with partition endpoint + private val options = opts.updated(ArangoDBConf.ENDPOINTS, inputPartition.endpoint) + private val actualSchema = StructType(ctx.requiredSchema.filterNot(_.name == options.readOptions.columnNameOfCorruptRecord)) + private val parser = ArangoParserProvider().of(options.driverOptions.contentType, actualSchema, options) + private val safeParser = new FailureSafeParser[Array[Byte]]( + parser.parse, + options.readOptions.parseMode, + ctx.requiredSchema, + options.readOptions.columnNameOfCorruptRecord) + private val client = ArangoClient(options) + private val iterator = client.readCollectionPartition(inputPartition.shardId, ctx.filters, actualSchema) + + var rowIterator: Iterator[InternalRow] = _ + + // warnings of non stream AQL cursors are all returned along with the first batch + if (!options.readOptions.stream) logWarns() + + @tailrec + final override def next: Boolean = + if (iterator.hasNext) { + val current = iterator.next() + rowIterator = safeParser.parse(current.get) + if (rowIterator.hasNext) { + true + } else { + next + } + } else { + // FIXME: https://arangodb.atlassian.net/browse/BTS-671 + // stream AQL cursors' warnings are only returned along with the final batch + if (options.readOptions.stream) logWarns() + false + } + + override def get: InternalRow = rowIterator.next() + + override def close(): Unit = { + iterator.close() + client.shutdown() + } + + private def logWarns(): Unit = Option(iterator.getWarnings).foreach(_.asScala.foreach((w: CursorWarning) => + logWarning(s"Got AQL warning: [${w.getCode}] ${w.getMessage}") + )) + +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoPartitionReaderFactory.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoPartitionReaderFactory.scala new file mode 100644 index 00000000..feacc04b --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoPartitionReaderFactory.scala @@ -0,0 +1,13 @@ +package 
org.apache.spark.sql.arangodb.datasource.reader + +import org.apache.spark.sql.arangodb.commons.ArangoDBConf +import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} + +class ArangoPartitionReaderFactory(ctx: PushDownCtx, options: ArangoDBConf) extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = partition match { + case p: ArangoCollectionPartition => new ArangoCollectionPartitionReader(p, ctx, options) + case SingletonPartition => new ArangoQueryReader(ctx.requiredSchema, options) + } +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala new file mode 100644 index 00000000..8d975f59 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoQueryReader.scala @@ -0,0 +1,63 @@ +package org.apache.spark.sql.arangodb.datasource.reader + +import com.arangodb.entity.CursorWarning +import org.apache.spark.internal.Logging +import org.apache.spark.sql.arangodb.commons.mapping.ArangoParserProvider +import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.FailureSafeParser +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.types._ + +import scala.annotation.tailrec +import scala.collection.JavaConverters.iterableAsScalaIterableConverter + + +class ArangoQueryReader(schema: StructType, options: ArangoDBConf) extends PartitionReader[InternalRow] with Logging { + + private val actualSchema = StructType(schema.filterNot(_.name == options.readOptions.columnNameOfCorruptRecord)) + private val parser = ArangoParserProvider().of(options.driverOptions.contentType, actualSchema, options) + private val safeParser = new FailureSafeParser[Array[Byte]]( + parser.parse, + options.readOptions.parseMode, + schema, + options.readOptions.columnNameOfCorruptRecord) + private val client = ArangoClient(options) + private val iterator = client.readQuery() + + var rowIterator: Iterator[InternalRow] = _ + + // warnings of non stream AQL cursors are all returned along with the first batch + if (!options.readOptions.stream) logWarns() + + @tailrec + final override def next: Boolean = + if (iterator.hasNext) { + val current = iterator.next() + rowIterator = safeParser.parse(current.get) + if (rowIterator.hasNext) { + true + } else { + next + } + } else { + // FIXME: https://arangodb.atlassian.net/browse/BTS-671 + // stream AQL cursors' warnings are only returned along with the final batch + if (options.readOptions.stream) logWarns() + false + } + + override def get: InternalRow = rowIterator.next() + + override def close(): Unit = { + iterator.close() + client.shutdown() + } + + private def logWarns(): Unit = Option(iterator.getWarnings).foreach(_.asScala.foreach((w: CursorWarning) => + logWarning(s"Got AQL warning: [${w.getCode}] ${w.getMessage}") + )) + +} + + diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScan.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScan.scala new file 
mode 100644 index 00000000..3feedac5 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScan.scala @@ -0,0 +1,28 @@ +package org.apache.spark.sql.arangodb.datasource.reader + +import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf, ReadMode} +import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx +import org.apache.spark.sql.catalyst.expressions.ExprUtils +import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan} +import org.apache.spark.sql.types.StructType + +class ArangoScan(ctx: PushDownCtx, options: ArangoDBConf) extends Scan with Batch { + ExprUtils.verifyColumnNameOfCorruptRecord(ctx.requiredSchema, options.readOptions.columnNameOfCorruptRecord) + + override def readSchema(): StructType = ctx.requiredSchema + + override def toBatch: Batch = this + + override def planInputPartitions(): Array[InputPartition] = options.readOptions.readMode match { + case ReadMode.Query => Array(SingletonPartition) + case ReadMode.Collection => planCollectionPartitions() + } + + override def createReaderFactory(): PartitionReaderFactory = new ArangoPartitionReaderFactory(ctx, options) + + private def planCollectionPartitions(): Array[InputPartition] = + ArangoClient.getCollectionShardIds(options) + .zip(Stream.continually(options.driverOptions.endpoints).flatten) + .map(it => new ArangoCollectionPartition(it._1, it._2)) + +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScanBuilder.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScanBuilder.scala new file mode 100644 index 00000000..5b439438 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/reader/ArangoScanBuilder.scala @@ -0,0 +1,66 @@ +package org.apache.spark.sql.arangodb.datasource.reader + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.arangodb.commons.{ArangoDBConf, ReadMode} +import org.apache.spark.sql.arangodb.commons.filter.{FilterSupport, PushableFilter} +import org.apache.spark.sql.arangodb.commons.utils.PushDownCtx +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType + +class ArangoScanBuilder(options: ArangoDBConf, tableSchema: StructType) extends ScanBuilder + with SupportsPushDownFilters + with SupportsPushDownRequiredColumns + with Logging { + + private var readSchema: StructType = _ + + // fully or partially applied filters + private var appliedPushableFilters: Array[PushableFilter] = Array() + private var appliedSparkFilters: Array[Filter] = Array() + + override def build(): Scan = new ArangoScan(new PushDownCtx(readSchema, appliedPushableFilters), options) + + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + options.readOptions.readMode match { + case ReadMode.Collection => pushFiltersReadModeCollection(filters) + case ReadMode.Query => filters + } + } + + private def pushFiltersReadModeCollection(filters: Array[Filter]): Array[Filter] = { + // filters related to columnNameOfCorruptRecord are not pushed down + val isCorruptRecordFilter = (f: Filter) => f.references.contains(options.readOptions.columnNameOfCorruptRecord) + val ignoredFilters = filters.filter(isCorruptRecordFilter) + val filtersBySupport = filters + 
.filterNot(isCorruptRecordFilter) + .map(f => (f, PushableFilter(f, tableSchema))) + .groupBy(_._2.support()) + + val fullSupp = filtersBySupport.getOrElse(FilterSupport.FULL, Array()) + val partialSupp = filtersBySupport.getOrElse(FilterSupport.PARTIAL, Array()) + val noneSupp = filtersBySupport.getOrElse(FilterSupport.NONE, Array()).map(_._1) ++ ignoredFilters + + val appliedFilters = fullSupp ++ partialSupp + appliedPushableFilters = appliedFilters.map(_._2) + appliedSparkFilters = appliedFilters.map(_._1) + + if (fullSupp.nonEmpty) { + logInfo(s"Filters fully applied in AQL:\n\t${fullSupp.map(_._1).mkString("\n\t")}") + } + if (partialSupp.nonEmpty) { + logInfo(s"Filters partially applied in AQL:\n\t${partialSupp.map(_._1).mkString("\n\t")}") + } + if (noneSupp.nonEmpty) { + logInfo(s"Filters not applied in AQL:\n\t${noneSupp.mkString("\n\t")}") + } + + partialSupp.map(_._1) ++ noneSupp + } + + override def pushedFilters(): Array[Filter] = appliedSparkFilters + + override def pruneColumns(requiredSchema: StructType): Unit = { + this.readSchema = requiredSchema + } +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoBatchWriter.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoBatchWriter.scala new file mode 100644 index 00000000..e7680f0e --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoBatchWriter.scala @@ -0,0 +1,30 @@ +package org.apache.spark.sql.arangodb.datasource.writer + +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf} +import org.apache.spark.sql.arangodb.commons.exceptions.DataWriteAbortException +import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage} +import org.apache.spark.sql.types.StructType + +class ArangoBatchWriter(schema: StructType, options: ArangoDBConf, mode: SaveMode) extends BatchWrite { + + override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = + new ArangoDataWriterFactory(schema, options) + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + // nothing to do here + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + val client = ArangoClient(options) + mode match { + case SaveMode.Append => throw new DataWriteAbortException( + "Cannot abort with SaveMode.Append: the underlying data source may require manual cleanup.") + case SaveMode.Overwrite => client.truncate() + case SaveMode.ErrorIfExists => ??? + case SaveMode.Ignore => ??? 
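+ // Note: ArangoWriterBuilder only builds this writer with SaveMode.Append or SaveMode.Overwrite, + // so the ErrorIfExists and Ignore branches are not expected to be reached.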
+ } + client.shutdown() + } + +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriter.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriter.scala new file mode 100644 index 00000000..b6577254 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriter.scala @@ -0,0 +1,136 @@ +package org.apache.spark.sql.arangodb.datasource.writer + +import com.arangodb.{ArangoDBException, ArangoDBMultipleException} +import com.arangodb.model.OverwriteMode +import com.arangodb.util.RawBytes +import org.apache.spark.internal.Logging +import org.apache.spark.sql.arangodb.commons.exceptions.{ArangoDBDataWriterException, DataWriteAbortException} +import org.apache.spark.sql.arangodb.commons.mapping.{ArangoGenerator, ArangoGeneratorProvider} +import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} +import org.apache.spark.sql.types.StructType + +import java.io.ByteArrayOutputStream +import java.net.{ConnectException, UnknownHostException} +import scala.annotation.tailrec +import scala.collection.JavaConverters.iterableAsScalaIterableConverter +import scala.util.Random + +class ArangoDataWriter(schema: StructType, options: ArangoDBConf, partitionId: Int) + extends DataWriter[InternalRow] with Logging { + + private var failures = 0 + private var exceptions: List[Exception] = List() + private var requestCount = 0L + private var endpointIdx = partitionId + private val endpoints = Stream.continually(options.driverOptions.endpoints).flatten + private val rnd = new Random() + private var client: ArangoClient = createClient() + private var batchCount: Int = _ + private var outStream: ByteArrayOutputStream = _ + private var vpackGenerator: ArangoGenerator = _ + + initBatch() + + override def write(record: InternalRow): Unit = { + vpackGenerator.write(record) + vpackGenerator.flush() + batchCount += 1 + if (batchCount == options.writeOptions.batchSize || outStream.size() > options.writeOptions.byteBatchSize) { + flushBatch() + initBatch() + } + } + + override def commit(): WriterCommitMessage = { + flushBatch() + null // scalastyle:ignore null + } + + /** + * Data cleanup will happen in [[ArangoBatchWriter.abort()]] + */ + override def abort(): Unit = if (!canRetry) { + client.shutdown() + throw new DataWriteAbortException( + "Task cannot be retried. 
To make batch writes idempotent, so that they can be retried, consider using " + + "'keep.null=true' (default) and 'overwrite.mode=(ignore|replace|update)'.") + } + + override def close(): Unit = { + client.shutdown() + } + + private def createClient() = ArangoClient(options.updated(ArangoDBConf.ENDPOINTS, endpoints(endpointIdx))) + + private def canRetry: Boolean = ArangoDataWriter.canRetry(schema, options) + + private def initBatch(): Unit = { + batchCount = 0 + outStream = new ByteArrayOutputStream() + vpackGenerator = ArangoGeneratorProvider().of(options.driverOptions.contentType, schema, outStream, options) + vpackGenerator.writeStartArray() + } + + private def flushBatch(): Unit = { + vpackGenerator.writeEndArray() + vpackGenerator.close() + vpackGenerator.flush() + logDebug(s"flushBatch(), bufferSize: ${outStream.size()}") + saveDocuments(RawBytes.of(outStream.toByteArray)) + } + + @tailrec private def saveDocuments(payload: RawBytes): Unit = { + try { + requestCount += 1 + logDebug(s"Sending request #$requestCount for partition $partitionId") + client.saveDocuments(payload) + logDebug(s"Received response #$requestCount for partition $partitionId") + failures = 0 + exceptions = List() + } catch { + case e: Exception => + client.shutdown() + failures += 1 + exceptions = e :: exceptions + endpointIdx += 1 + if ((canRetry || isConnectionException(e)) && failures < options.writeOptions.maxAttempts) { + val delay = computeDelay() + logWarning(s"Got exception while saving documents, retrying in $delay ms:", e) + Thread.sleep(delay) + client = createClient() + saveDocuments(payload) + } else { + throw new ArangoDBDataWriterException(exceptions.reverse.toArray) + } + } + } + + private def computeDelay(): Int = { + val min = options.writeOptions.minRetryDelay + val max = options.writeOptions.maxRetryDelay + val diff = max - min + val delta = if (diff <= 0) 0 else rnd.nextInt(diff) + min + delta + } + + private def isConnectionException(e: Throwable): Boolean = e match { + case ae: ArangoDBException => isConnectionException(ae.getCause) + case me: ArangoDBMultipleException => me.getExceptions.asScala.forall(isConnectionException) + case _: ConnectException => true + case _: UnknownHostException => true + case _ => false + } + +} + +object ArangoDataWriter { + def canRetry(schema: StructType, options: ArangoDBConf): Boolean = + schema.exists(p => p.name == "_key" && !p.nullable) && (options.writeOptions.overwriteMode match { + case OverwriteMode.ignore => true + case OverwriteMode.replace => true + case OverwriteMode.update => options.writeOptions.keepNull + case OverwriteMode.conflict => false + }) +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriterFactory.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriterFactory.scala new file mode 100644 index 00000000..d4513fd7 --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoDataWriterFactory.scala @@ -0,0 +1,12 @@ +package org.apache.spark.sql.arangodb.datasource.writer + +import org.apache.spark.sql.arangodb.commons.ArangoDBConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory} +import org.apache.spark.sql.types.StructType + +class ArangoDataWriterFactory(schema: StructType, options: ArangoDBConf) extends DataWriterFactory { + override def createWriter(partitionId: Int, 
taskId: Long): DataWriter[InternalRow] = { + new ArangoDataWriter(schema, options, partitionId) + } +} diff --git a/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoWriterBuilder.scala b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoWriterBuilder.scala new file mode 100644 index 00000000..4336e88f --- /dev/null +++ b/arangodb-spark-datasource-3.4/src/main/scala/org/apache/spark/sql/arangodb/datasource/writer/ArangoWriterBuilder.scala @@ -0,0 +1,93 @@ +package org.apache.spark.sql.arangodb.datasource.writer + +import com.arangodb.entity.CollectionType +import com.arangodb.model.OverwriteMode +import org.apache.spark.internal.Logging +import org.apache.spark.sql.arangodb.commons.{ArangoClient, ArangoDBConf, ContentType} +import org.apache.spark.sql.connector.write.{BatchWrite, SupportsTruncate, WriteBuilder} +import org.apache.spark.sql.types.{DecimalType, StringType, StructType} +import org.apache.spark.sql.{AnalysisException, SaveMode} + +class ArangoWriterBuilder(schema: StructType, options: ArangoDBConf) + extends WriteBuilder with SupportsTruncate with Logging { + + private var mode: SaveMode = SaveMode.Append + validateConfig() + + override def buildForBatch(): BatchWrite = { + val client = ArangoClient(options) + if (!client.collectionExists()) { + client.createCollection() + } + client.shutdown() + + val updatedOptions = options.updated(ArangoDBConf.OVERWRITE_MODE, mode match { + case SaveMode.Append => options.writeOptions.overwriteMode.getValue + case _ => OverwriteMode.ignore.getValue + }) + + logSummary(updatedOptions) + new ArangoBatchWriter(schema, updatedOptions, mode) + } + + override def truncate(): WriteBuilder = { + mode = SaveMode.Overwrite + if (options.writeOptions.confirmTruncate) { + val client = ArangoClient(options) + if (client.collectionExists()) { + client.truncate() + } else { + client.createCollection() + } + client.shutdown() + this + } else { + throw new AnalysisException( + "You are attempting to use overwrite mode, which will truncate this collection prior to inserting data. If " + + "you just want to change data already in the collection, set save mode 'append' and " + + s"'overwrite.mode=(replace|update)'. 
To actually truncate, set '${ArangoDBConf.CONFIRM_TRUNCATE}=true'.") + } + } + + private def validateConfig(): Unit = { + if (options.driverOptions.contentType == ContentType.JSON && hasDecimalTypeFields) { + throw new UnsupportedOperationException("Cannot write DecimalType when using contentType=json") + } + + if (options.writeOptions.collectionType == CollectionType.EDGES && + !schema.exists(p => p.name == "_from" && p.dataType == StringType && !p.nullable) + ) { + throw new IllegalArgumentException("Writing an edge collection requires a non-nullable string field named '_from'.") + } + + if (options.writeOptions.collectionType == CollectionType.EDGES && + !schema.exists(p => p.name == "_to" && p.dataType == StringType && !p.nullable) + ) { + throw new IllegalArgumentException("Writing an edge collection requires a non-nullable string field named '_to'.") + } + } + + private def hasDecimalTypeFields: Boolean = + schema.existsRecursively { + case _: DecimalType => true + case _ => false + } + + private def logSummary(updatedOptions: ArangoDBConf): Unit = { + val canRetry = ArangoDataWriter.canRetry(schema, updatedOptions) + + logInfo(s"Using save mode: $mode") + logInfo(s"Using write configuration: ${updatedOptions.writeOptions}") + logInfo(s"Using mapping configuration: ${updatedOptions.mappingOptions}") + logInfo(s"Can retry: $canRetry") + + if (!canRetry) { + logWarning( + """The provided configuration does not allow idempotent requests: write failures will not be retried and will lead + |to task failure. Speculative task executions could fail or write incorrect data.""" + .stripMargin.replaceAll("\n", "") + ) + } + } + +} diff --git a/bin/clean.sh b/bin/clean.sh index be538923..9a8398b0 100755 --- a/bin/clean.sh +++ b/bin/clean.sh @@ -7,3 +7,5 @@ mvn clean -Pspark-3.2 -Pscala-2.12 mvn clean -Pspark-3.2 -Pscala-2.13 mvn clean -Pspark-3.3 -Pscala-2.12 mvn clean -Pspark-3.3 -Pscala-2.13 +mvn clean -Pspark-3.4 -Pscala-2.12 +mvn clean -Pspark-3.4 -Pscala-2.13 diff --git a/bin/test.sh b/bin/test.sh index 498635c6..fe1f78be 100755 --- a/bin/test.sh +++ b/bin/test.sh @@ -23,3 +23,9 @@ mvn test -Pspark-3.3 -Pscala-2.12 mvn clean -Pspark-3.3 -Pscala-2.13 mvn test -Pspark-3.3 -Pscala-2.13 + +mvn clean -Pspark-3.4 -Pscala-2.12 +mvn test -Pspark-3.4 -Pscala-2.12 + +mvn clean -Pspark-3.4 -Pscala-2.13 +mvn test -Pspark-3.4 -Pscala-2.13 diff --git a/demo/README.md b/demo/README.md index 9aced3c6..b82a5ee8 100644 --- a/demo/README.md +++ b/demo/README.md @@ -79,7 +79,7 @@ docker run -it --rm \ -v $(pwd):/demo \ -v $(pwd)/docker/.ivy2:/opt/bitnami/spark/.ivy2 \ --network arangodb \ - docker.io/bitnami/spark:3.2.1 \ + docker.io/bitnami/spark:3.2.4 \ ./bin/spark-submit --master spark://spark-master:7077 \ --packages="com.arangodb:arangodb-spark-datasource-3.2_2.12:$ARANGO_SPARK_VERSION" \ --class Demo /demo/target/demo-$ARANGO_SPARK_VERSION.jar diff --git a/demo/docker/start_spark_3.2.sh b/demo/docker/start_spark_3.2.sh index c64ecad7..1dba1f70 100755 --- a/demo/docker/start_spark_3.2.sh +++ b/demo/docker/start_spark_3.2.sh @@ -9,7 +9,7 @@ docker run -d --network arangodb --ip 172.28.10.1 --name spark-master -h spark-m -e SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \ -e SPARK_SSL_ENABLED=no \ -v $(pwd)/docker/import:/import \ - docker.io/bitnami/spark:3.2.1 + docker.io/bitnami/spark:3.2.4 docker run -d --network arangodb --ip 172.28.10.11 --name spark-worker-1 -h spark-worker-1 \ -e SPARK_MODE=worker \ @@ -21,7 +21,7 @@ docker run -d --network arangodb --ip 172.28.10.11 --name spark-worker-1 -h spar -e
SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \ -e SPARK_SSL_ENABLED=no \ -v $(pwd)/docker/import:/import \ - docker.io/bitnami/spark:3.2.1 + docker.io/bitnami/spark:3.2.4 docker run -d --network arangodb --ip 172.28.10.12 --name spark-worker-2 -h spark-worker-2 \ -e SPARK_MODE=worker \ @@ -33,7 +33,7 @@ docker run -d --network arangodb --ip 172.28.10.12 --name spark-worker-2 -h spar -e SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \ -e SPARK_SSL_ENABLED=no \ -v $(pwd)/docker/import:/import \ - docker.io/bitnami/spark:3.2.1 + docker.io/bitnami/spark:3.2.4 docker run -d --network arangodb --ip 172.28.10.13 --name spark-worker-3 -h spark-worker-3 \ -e SPARK_MODE=worker \ @@ -45,4 +45,4 @@ docker run -d --network arangodb --ip 172.28.10.13 --name spark-worker-3 -h spar -e SPARK_LOCAL_STORAGE_ENCRYPTION_ENABLED=no \ -e SPARK_SSL_ENABLED=no \ -v $(pwd)/docker/import:/import \ - docker.io/bitnami/spark:3.2.1 + docker.io/bitnami/spark:3.2.4 diff --git a/demo/pom.xml b/demo/pom.xml index ceb0609c..e9a43656 100644 --- a/demo/pom.xml +++ b/demo/pom.xml @@ -55,21 +55,28 @@ spark-3.2 + + true + - 3.2.1 + 3.2.4 3.2 spark-3.3 - - true - 3.3.2 3.3 + + spark-3.4 + + 3.4.0 + 3.4 + + diff --git a/docker/start_spark_2.4.sh b/docker/start_spark_2.4.sh deleted file mode 100755 index 09999280..00000000 --- a/docker/start_spark_2.4.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -docker network create arangodb --subnet 172.28.0.0/16 -docker run --network arangodb --ip 172.28.10.1 --name spark-master -h spark-master -e ENABLE_INIT_DAEMON=false -d bde2020/spark-master:2.4.5-hadoop2.7 -docker run --network arangodb --ip 172.28.10.11 --name spark-worker-1 -h spark-worker-1 -e SPARK_WORKER_CORES=1 -e ENABLE_INIT_DAEMON=false -d bde2020/spark-worker:2.4.5-hadoop2.7 -docker run --network arangodb --ip 172.28.10.12 --name spark-worker-2 -h spark-worker-2 -e SPARK_WORKER_CORES=1 -e ENABLE_INIT_DAEMON=false -d bde2020/spark-worker:2.4.5-hadoop2.7 -docker run --network arangodb --ip 172.28.10.13 --name spark-worker-3 -h spark-worker-3 -e SPARK_WORKER_CORES=1 -e ENABLE_INIT_DAEMON=false -d bde2020/spark-worker:2.4.5-hadoop2.7 diff --git a/docker/start_spark_3.1.sh b/docker/start_spark_3.1.sh deleted file mode 100755 index e3c3bb06..00000000 --- a/docker/start_spark_3.1.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -docker network create arangodb --subnet 172.28.0.0/16 -docker run --network arangodb --ip 172.28.10.1 --name spark-master -h spark-master -e ENABLE_INIT_DAEMON=false -d bde2020/spark-master:3.1.1-hadoop3.2 -docker run --network arangodb --ip 172.28.10.11 --name spark-worker-1 -h spark-worker-1 -e SPARK_WORKER_CORES=1 -e ENABLE_INIT_DAEMON=false -d bde2020/spark-worker:3.1.1-hadoop3.2 -docker run --network arangodb --ip 172.28.10.12 --name spark-worker-2 -h spark-worker-2 -e SPARK_WORKER_CORES=1 -e ENABLE_INIT_DAEMON=false -d bde2020/spark-worker:3.1.1-hadoop3.2 -docker run --network arangodb --ip 172.28.10.13 --name spark-worker-3 -h spark-worker-3 -e SPARK_WORKER_CORES=1 -e ENABLE_INIT_DAEMON=false -d bde2020/spark-worker:3.1.1-hadoop3.2 diff --git a/docker/stop.sh b/docker/stop.sh deleted file mode 100755 index c3f9aa22..00000000 --- a/docker/stop.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash - -docker exec adb /app/arangodb stop -sleep 1 -docker rm -f \ - adb \ - spark-master \ - spark-worker-1 \ - spark-worker-2 \ - spark-worker-3 diff --git a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/DeserializationCastTest.scala 
b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/DeserializationCastTest.scala index 6fe0e22f..a56533eb 100644 --- a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/DeserializationCastTest.scala +++ b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/DeserializationCastTest.scala @@ -62,6 +62,7 @@ class DeserializationCastTest extends BaseSparkTest { def nullToIntegerCast(contentType: String): Unit = { // FIXME: DE-599 assumeTrue(!SPARK_VERSION.startsWith("3.3")) + assumeTrue(!SPARK_VERSION.startsWith("3.4")) doTestImplicitCast( StructType(Array(StructField("a", IntegerType, nullable = false))), @@ -76,6 +77,7 @@ class DeserializationCastTest extends BaseSparkTest { def nullToDoubleCast(contentType: String): Unit = { // FIXME: DE-599 assumeTrue(!SPARK_VERSION.startsWith("3.3")) + assumeTrue(!SPARK_VERSION.startsWith("3.4")) doTestImplicitCast( StructType(Array(StructField("a", DoubleType, nullable = false))), @@ -90,6 +92,7 @@ class DeserializationCastTest extends BaseSparkTest { def nullAsBoolean(contentType: String): Unit = { // FIXME: DE-599 assumeTrue(!SPARK_VERSION.startsWith("3.3")) + assumeTrue(!SPARK_VERSION.startsWith("3.4")) doTestImplicitCast( StructType(Array(StructField("a", BooleanType, nullable = false))), diff --git a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/AbortTest.scala b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/AbortTest.scala index 242b4651..6386ef5b 100644 --- a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/AbortTest.scala +++ b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/AbortTest.scala @@ -79,7 +79,8 @@ class AbortTest extends BaseSparkTest { ArangoDBConf.PROTOCOL -> protocol, ArangoDBConf.CONTENT_TYPE -> contentType, ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue, - ArangoDBConf.COLLECTION_TYPE -> CollectionType.EDGE.name + ArangoDBConf.COLLECTION_TYPE -> CollectionType.EDGE.name, + ArangoDBConf.BATCH_SIZE -> "9" )) .save() }) @@ -113,7 +114,8 @@ class AbortTest extends BaseSparkTest { ArangoDBConf.PROTOCOL -> protocol, ArangoDBConf.CONTENT_TYPE -> contentType, ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue, - ArangoDBConf.COLLECTION_TYPE -> CollectionType.EDGE.name + ArangoDBConf.COLLECTION_TYPE -> CollectionType.EDGE.name, + ArangoDBConf.BATCH_SIZE -> "9" )) .save() }) @@ -150,7 +152,8 @@ class AbortTest extends BaseSparkTest { ArangoDBConf.COLLECTION -> collectionName, ArangoDBConf.PROTOCOL -> protocol, ArangoDBConf.CONTENT_TYPE -> contentType, - ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue + ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue, + ArangoDBConf.BATCH_SIZE -> "9" )) .save() }) @@ -181,14 +184,22 @@ class AbortTest extends BaseSparkTest { ArangoDBConf.PROTOCOL -> protocol, ArangoDBConf.CONTENT_TYPE -> contentType, ArangoDBConf.CONFIRM_TRUNCATE -> "true", - ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue + ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.replace.getValue, + ArangoDBConf.BATCH_SIZE -> "1" )) .save() }) assertThat(thrown).isInstanceOf(classOf[SparkException]) - assertThat(thrown.getCause.getCause).isInstanceOf(classOf[ArangoDBDataWriterException]) - val rootEx = thrown.getCause.getCause.getCause + + val cause = if(SPARK_VERSION.startsWith("3.4")) { + thrown.getCause + } else { + thrown.getCause.getCause + } + + 
assertThat(cause).isInstanceOf(classOf[ArangoDBDataWriterException]) + val rootEx = cause.getCause assertThat(rootEx).isInstanceOf(classOf[ArangoDBMultiException]) val errors = rootEx.asInstanceOf[ArangoDBMultiException].errors assertThat(errors.length).isEqualTo(1) diff --git a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/OverwriteModeTest.scala b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/OverwriteModeTest.scala index d06bd78e..3cd8c3f4 100644 --- a/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/OverwriteModeTest.scala +++ b/integration-tests/src/test/scala/org/apache/spark/sql/arangodb/datasource/write/OverwriteModeTest.scala @@ -58,7 +58,8 @@ class OverwriteModeTest extends BaseSparkTest { ArangoDBConf.COLLECTION -> collectionName, ArangoDBConf.PROTOCOL -> protocol, ArangoDBConf.CONTENT_TYPE -> contentType, - ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.conflict.getValue + ArangoDBConf.OVERWRITE_MODE -> OverwriteMode.conflict.getValue, + ArangoDBConf.BATCH_SIZE -> "3" )) .save() }) diff --git a/pom.xml b/pom.xml index 89ca2ba2..bc24f921 100644 --- a/pom.xml +++ b/pom.xml @@ -117,6 +117,15 @@ 4.1.0 + + spark-3.4 + + 3.4.0 + 3.4 + + 4.1.0 + + no-deploy
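
For reference, the sketch below shows how the retriable write path added in ArangoDataWriter could be exercised from a job built against the new spark-3.4 profile (artifact named following the existing pattern, e.g. com.arangodb:arangodb-spark-datasource-3.4_2.12). A non-nullable '_key' column combined with 'overwrite.mode=replace' makes ArangoDataWriter.canRetry return true, so a failed batch is retried with a randomized delay instead of failing the task outright. The short format name 'com.arangodb.spark' and the 'endpoints', 'password' and 'table' option keys, as well as all endpoint, credential and collection values, are assumptions for illustration and are not taken from this patch; only 'overwrite.mode' and 'keep.null' are spelled out in the code above.

    import org.apache.spark.sql.types.{StringType, StructField, StructType}
    import org.apache.spark.sql.{Row, SaveMode, SparkSession}

    object RetriableWriteSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("arangodb-spark-datasource-3.4-sketch")
          .master("local[*]")
          .getOrCreate()

        // Declaring "_key" as non-nullable is what allows ArangoDataWriter.canRetry to return true.
        val schema = StructType(Seq(
          StructField("_key", StringType, nullable = false),
          StructField("name", StringType, nullable = true)
        ))
        val rows = java.util.Arrays.asList(Row("1", "Anna"), Row("2", "Bob"))
        val df = spark.createDataFrame(rows, schema)

        df.write
          .format("com.arangodb.spark")          // assumed short name registered by DefaultSource
          .mode(SaveMode.Append)
          .option("endpoints", "localhost:8529") // placeholder coordinator endpoint (assumed option key)
          .option("password", "test")            // placeholder credentials (assumed option key)
          .option("table", "users")              // placeholder target collection (assumed option key)
          .option("overwrite.mode", "replace")   // with the non-nullable _key this makes batches idempotent
          .save()

        spark.stop()
      }
    }

With such a configuration, connection errors and server-side write errors are retried up to the configured maximum number of attempts, with a delay drawn between the configured minimum and maximum retry delays, as implemented in saveDocuments and computeDelay above.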