diff --git a/.github/labeler.yml b/.github/labeler.yml
index cf1d2a7117203..84dfa35f2627e 100644
--- a/.github/labeler.yml
+++ b/.github/labeler.yml
@@ -155,3 +155,6 @@ CONNECT:
- "connector/connect/**/*"
- "**/sql/sparkconnect/**/*"
- "python/pyspark/sql/**/connect/**/*"
+PROTOBUF:
+ - "connector/protobuf/**/*"
+ - "python/pyspark/sql/protobuf/**/*"
\ No newline at end of file
diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml
new file mode 100644
index 0000000000000..0515f128b8d63
--- /dev/null
+++ b/connector/protobuf/pom.xml
@@ -0,0 +1,115 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+-->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.12</artifactId>
+    <version>3.4.0-SNAPSHOT</version>
+    <relativePath>../../pom.xml</relativePath>
+  </parent>
+
+  <artifactId>spark-protobuf_2.12</artifactId>
+  <properties>
+    <sbt.project.name>protobuf</sbt.project.name>
+    <protobuf.version>3.21.1</protobuf.version>
+  </properties>
+  <packaging>jar</packaging>
+  <name>Spark Protobuf</name>
+  <url>https://spark.apache.org/</url>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalacheck</groupId>
+      <artifactId>scalacheck_${scala.binary.version}</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-tags_${scala.binary.version}</artifactId>
+    </dependency>
+
+    <dependency>
+      <groupId>com.google.protobuf</groupId>
+      <artifactId>protobuf-java</artifactId>
+      <version>${protobuf.version}</version>
+      <scope>compile</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+    <plugins>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-shade-plugin</artifactId>
+        <configuration>
+          <shadedArtifactAttached>false</shadedArtifactAttached>
+          <artifactSet>
+            <includes>
+              <include>com.google.protobuf:*</include>
+            </includes>
+          </artifactSet>
+          <relocations>
+            <relocation>
+              <pattern>com.google.protobuf</pattern>
+              <shadedPattern>${spark.shade.packageName}.spark-protobuf.protobuf</shadedPattern>
+              <includes>
+                <include>com.google.protobuf.**</include>
+              </includes>
+            </relocation>
+          </relocations>
+        </configuration>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
new file mode 100644
index 0000000000000..145100268c232
--- /dev/null
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.protobuf
+
+import com.google.protobuf.DynamicMessage
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.protobuf.utils.ProtobufUtils
+import org.apache.spark.sql.types.{BinaryType, DataType}
+
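+/**
+ * Catalyst expression that serializes a Catalyst value (typically a struct column) into
+ * Protobuf binary format, using the message descriptor for `messageName` loaded from the
+ * descriptor file at `descFilePath`.
+ */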
+private[protobuf] case class CatalystDataToProtobuf(
+ child: Expression,
+ descFilePath: String,
+ messageName: String)
+ extends UnaryExpression {
+
+ override def dataType: DataType = BinaryType
+
+ @transient private lazy val protoType =
+ ProtobufUtils.buildDescriptor(descFilePath, messageName)
+
+ @transient private lazy val serializer =
+ new ProtobufSerializer(child.dataType, protoType, child.nullable)
+
+ override def nullSafeEval(input: Any): Any = {
+ val dynamicMessage = serializer.serialize(input).asInstanceOf[DynamicMessage]
+ dynamicMessage.toByteArray
+ }
+
+ override def prettyName: String = "to_protobuf"
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val expr = ctx.addReferenceObj("this", this)
+ defineCodeGen(ctx, ev, input => s"(byte[]) $expr.nullSafeEval($input)")
+ }
+
+ override protected def withNewChildInternal(newChild: Expression): CatalystDataToProtobuf =
+ copy(child = newChild)
+}
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
new file mode 100644
index 0000000000000..f08f876799723
--- /dev/null
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.protobuf
+
+import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
+
+import com.google.protobuf.DynamicMessage
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, SpecificInternalRow, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode}
+import org.apache.spark.sql.protobuf.utils.{ProtobufOptions, ProtobufUtils, SchemaConverters}
+import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, StructType}
+
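+/**
+ * Catalyst expression that deserializes a Protobuf binary column into its Catalyst
+ * representation, using the message descriptor for `messageName` loaded from the descriptor
+ * file at `descFilePath`. The "mode" option controls how malformed records are handled:
+ * PERMISSIVE produces a null row, FAILFAST throws an exception.
+ */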
+private[protobuf] case class ProtobufDataToCatalyst(
+ child: Expression,
+ descFilePath: String,
+ messageName: String,
+ options: Map[String, String])
+ extends UnaryExpression
+ with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
+
+ override lazy val dataType: DataType = {
+ val dt = SchemaConverters.toSqlType(messageDescriptor).dataType
+ parseMode match {
+ // With PermissiveMode, the output Catalyst row might contain columns of null values for
+ // corrupt records, even if some of the columns are not nullable in the user-provided schema.
+ // Therefore we force the schema to be all nullable here.
+ case PermissiveMode => dt.asNullable
+ case _ => dt
+ }
+ }
+
+ override def nullable: Boolean = true
+
+ private lazy val protobufOptions = ProtobufOptions(options)
+
+ @transient private lazy val messageDescriptor =
+ ProtobufUtils.buildDescriptor(descFilePath, messageName)
+
+ @transient private lazy val fieldsNumbers =
+ messageDescriptor.getFields.asScala.map(f => f.getNumber)
+
+ @transient private lazy val deserializer = new ProtobufDeserializer(messageDescriptor, dataType)
+
+ @transient private var result: DynamicMessage = _
+
+ @transient private lazy val parseMode: ParseMode = {
+ val mode = protobufOptions.parseMode
+ if (mode != PermissiveMode && mode != FailFastMode) {
+ throw new AnalysisException(unacceptableModeMessage(mode.name))
+ }
+ mode
+ }
+
+ private def unacceptableModeMessage(name: String): String = {
+ s"from_protobuf() doesn't support the $name mode. " +
+ s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}."
+ }
+
+ @transient private lazy val nullResultRow: Any = dataType match {
+ case st: StructType =>
+ val resultRow = new SpecificInternalRow(st.map(_.dataType))
+ for (i <- 0 until st.length) {
+ resultRow.setNullAt(i)
+ }
+ resultRow
+
+ case _ =>
+ null
+ }
+
+ private def handleException(e: Throwable): Any = {
+ parseMode match {
+ case PermissiveMode =>
+ nullResultRow
+ case FailFastMode =>
+ throw new SparkException(
+ "Malformed records are detected in record parsing. " +
+ s"Current parse Mode: ${FailFastMode.name}. To process malformed records as null " +
+ "result, try setting the option 'mode' as 'PERMISSIVE'.",
+ e)
+ case _ =>
+ throw new AnalysisException(unacceptableModeMessage(parseMode.name))
+ }
+ }
+
+ override def nullSafeEval(input: Any): Any = {
+ val binary = input.asInstanceOf[Array[Byte]]
+ try {
+ result = DynamicMessage.parseFrom(messageDescriptor, binary)
+ val unknownFields = result.getUnknownFields
+      if (!unknownFields.asMap().isEmpty) {
+        // An unknown field whose field number is declared in the descriptor indicates that
+        // the wire data did not match the type declared for that field.
+        unknownFields.asMap().keySet().asScala.foreach { number =>
+          if (fieldsNumbers.contains(number)) {
+            return handleException(
+              new Throwable(s"Type mismatch encountered for field:" +
+                s" ${messageDescriptor.findFieldByNumber(number)}"))
+          }
+        }
+      }
+ val deserialized = deserializer.deserialize(result)
+ assert(
+ deserialized.isDefined,
+ "Protobuf deserializer cannot return an empty result because filters are not pushed down")
+ deserialized.get
+ } catch {
+ // Several kinds of exceptions can be thrown here, e.g. java.io.IOException,
+ // InvalidProtocolBufferException, ArrayIndexOutOfBoundsException, etc.
+ // To keep it simple, catch all non-fatal exceptions here.
+ case NonFatal(e) =>
+ handleException(e)
+ }
+ }
+
+ override def prettyName: String = "from_protobuf"
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val expr = ctx.addReferenceObj("this", this)
+ nullSafeCodeGen(
+ ctx,
+ ev,
+ eval => {
+ val result = ctx.freshName("result")
+ val dt = CodeGenerator.boxedType(dataType)
+ s"""
+ $dt $result = ($dt) $expr.nullSafeEval($eval);
+ if ($result == null) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.value} = $result;
+ }
+ """
+ })
+ }
+
+ override protected def withNewChildInternal(newChild: Expression): ProtobufDataToCatalyst =
+ copy(child = newChild)
+}
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
new file mode 100644
index 0000000000000..0403b741ebfa7
--- /dev/null
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala
@@ -0,0 +1,357 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.protobuf
+
+import java.util.concurrent.TimeUnit
+
+import com.google.protobuf.{ByteString, DynamicMessage, Message}
+import com.google.protobuf.Descriptors._
+import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
+
+import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters}
+import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.protobuf.utils.ProtobufUtils
+import org.apache.spark.sql.protobuf.utils.ProtobufUtils.ProtoMatchedField
+import org.apache.spark.sql.protobuf.utils.ProtobufUtils.toFieldStr
+import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
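+/**
+ * A deserializer that converts Protobuf [[DynamicMessage]] values into Catalyst values
+ * matching `rootCatalystType`, optionally applying pushed-down `filters` to skip rows.
+ */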
+private[sql] class ProtobufDeserializer(
+ rootDescriptor: Descriptor,
+ rootCatalystType: DataType,
+ filters: StructFilters) {
+
+ def this(rootDescriptor: Descriptor, rootCatalystType: DataType) = {
+ this(rootDescriptor, rootCatalystType, new NoopFilters)
+ }
+
+ private val converter: Any => Option[InternalRow] =
+ try {
+ rootCatalystType match {
+ // A shortcut for empty schema.
+ case st: StructType if st.isEmpty =>
+ (_: Any) => Some(InternalRow.empty)
+
+ case st: StructType =>
+ val resultRow = new SpecificInternalRow(st.map(_.dataType))
+ val fieldUpdater = new RowUpdater(resultRow)
+ val applyFilters = filters.skipRow(resultRow, _)
+ val writer = getRecordWriter(rootDescriptor, st, Nil, Nil, applyFilters)
+ (data: Any) => {
+ val record = data.asInstanceOf[DynamicMessage]
+ val skipRow = writer(fieldUpdater, record)
+ if (skipRow) None else Some(resultRow)
+ }
+ }
+ } catch {
+ case ise: IncompatibleSchemaException =>
+ throw new IncompatibleSchemaException(
+ s"Cannot convert Protobuf type ${rootDescriptor.getName} " +
+ s"to SQL type ${rootCatalystType.sql}.",
+ ise)
+ }
+
+ def deserialize(data: Message): Option[InternalRow] = converter(data)
+
+ private def newArrayWriter(
+ protoField: FieldDescriptor,
+ protoPath: Seq[String],
+ catalystPath: Seq[String],
+ elementType: DataType,
+ containsNull: Boolean): (CatalystDataUpdater, Int, Any) => Unit = {
+
+ val protoElementPath = protoPath :+ "element"
+ val elementWriter =
+ newWriter(protoField, elementType, protoElementPath, catalystPath :+ "element")
+ (updater, ordinal, value) =>
+ val collection = value.asInstanceOf[java.util.Collection[Any]]
+ val result = createArrayData(elementType, collection.size())
+ val elementUpdater = new ArrayDataUpdater(result)
+
+ var i = 0
+ val iterator = collection.iterator()
+ while (iterator.hasNext) {
+ val element = iterator.next()
+ if (element == null) {
+ if (!containsNull) {
+ throw QueryCompilationErrors.nullableArrayOrMapElementError(protoElementPath)
+ } else {
+ elementUpdater.setNullAt(i)
+ }
+ } else {
+ elementWriter(elementUpdater, i, element)
+ }
+ i += 1
+ }
+
+ updater.set(ordinal, result)
+ }
+
+ private def newMapWriter(
+ protoType: FieldDescriptor,
+ protoPath: Seq[String],
+ catalystPath: Seq[String],
+ keyType: DataType,
+ valueType: DataType,
+ valueContainsNull: Boolean): (CatalystDataUpdater, Int, Any) => Unit = {
+ val keyField = protoType.getMessageType.getFields.get(0)
+ val valueField = protoType.getMessageType.getFields.get(1)
+ val keyWriter = newWriter(keyField, keyType, protoPath :+ "key", catalystPath :+ "key")
+ val valueWriter =
+ newWriter(valueField, valueType, protoPath :+ "value", catalystPath :+ "value")
+ (updater, ordinal, value) =>
+ if (value != null) {
+ val messageList = value.asInstanceOf[java.util.List[com.google.protobuf.Message]]
+ val valueArray = createArrayData(valueType, messageList.size())
+ val valueUpdater = new ArrayDataUpdater(valueArray)
+ val keyArray = createArrayData(keyType, messageList.size())
+ val keyUpdater = new ArrayDataUpdater(keyArray)
+ var i = 0
+ messageList.forEach { field =>
+ {
+ keyWriter(keyUpdater, i, field.getField(keyField))
+ if (field.getField(valueField) == null) {
+ if (!valueContainsNull) {
+ throw QueryCompilationErrors.nullableArrayOrMapElementError(protoPath)
+ } else {
+ valueUpdater.setNullAt(i)
+ }
+ } else {
+ valueWriter(valueUpdater, i, field.getField(valueField))
+ }
+ }
+ i += 1
+ }
+ updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray))
+ }
+ }
+
+ /**
+ * Creates a writer to write Protobuf values to Catalyst values at the given ordinal with the
+ * given updater.
+ */
+ private def newWriter(
+ protoType: FieldDescriptor,
+ catalystType: DataType,
+ protoPath: Seq[String],
+ catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = {
+ val errorPrefix = s"Cannot convert Protobuf ${toFieldStr(protoPath)} to " +
+ s"SQL ${toFieldStr(catalystPath)} because "
+ val incompatibleMsg = errorPrefix +
+ s"schema is incompatible (protoType = ${protoType} ${protoType.toProto.getLabel} " +
+ s"${protoType.getJavaType} ${protoType.getType}, sqlType = ${catalystType.sql})"
+
+ (protoType.getJavaType, catalystType) match {
+
+ case (null, NullType) => (updater, ordinal, _) => updater.setNullAt(ordinal)
+
+ // TODO: we can avoid boxing if a future version of Protobuf provides primitive accessors.
+ case (BOOLEAN, BooleanType) =>
+ (updater, ordinal, value) => updater.setBoolean(ordinal, value.asInstanceOf[Boolean])
+
+ case (INT, IntegerType) =>
+ (updater, ordinal, value) => updater.setInt(ordinal, value.asInstanceOf[Int])
+
+ case (INT, ByteType) =>
+ (updater, ordinal, value) => updater.setByte(ordinal, value.asInstanceOf[Byte])
+
+ case (INT, ShortType) =>
+ (updater, ordinal, value) => updater.setShort(ordinal, value.asInstanceOf[Short])
+
+ case (BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING,
+ ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated =>
+ newArrayWriter(protoType, protoPath, catalystPath, dataType, containsNull)
+
+ case (LONG, LongType) =>
+ (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long])
+
+ case (FLOAT, FloatType) =>
+ (updater, ordinal, value) => updater.setFloat(ordinal, value.asInstanceOf[Float])
+
+ case (DOUBLE, DoubleType) =>
+ (updater, ordinal, value) => updater.setDouble(ordinal, value.asInstanceOf[Double])
+
+ case (STRING, StringType) =>
+ (updater, ordinal, value) =>
+ val str = value match {
+ case s: String => UTF8String.fromString(s)
+ }
+ updater.set(ordinal, str)
+
+ case (BYTE_STRING, BinaryType) =>
+ (updater, ordinal, value) =>
+ val byte_array = value match {
+ case s: ByteString => s.toByteArray
+ case _ => throw new Exception("Invalid ByteString format")
+ }
+ updater.set(ordinal, byte_array)
+
+ case (MESSAGE, MapType(keyType, valueType, valueContainsNull)) =>
+ newMapWriter(protoType, protoPath, catalystPath, keyType, valueType, valueContainsNull)
+
+ case (MESSAGE, TimestampType) =>
+ (updater, ordinal, value) =>
+ val secondsField = protoType.getMessageType.getFields.get(0)
+ val nanoSecondsField = protoType.getMessageType.getFields.get(1)
+ val message = value.asInstanceOf[DynamicMessage]
+ val seconds = message.getField(secondsField).asInstanceOf[Long]
+ val nanoSeconds = message.getField(nanoSecondsField).asInstanceOf[Int]
+ val micros = DateTimeUtils.millisToMicros(seconds * 1000)
+ updater.setLong(ordinal, micros + TimeUnit.NANOSECONDS.toMicros(nanoSeconds))
+
+ case (MESSAGE, DayTimeIntervalType(startField, endField)) =>
+ (updater, ordinal, value) =>
+ val secondsField = protoType.getMessageType.getFields.get(0)
+ val nanoSecondsField = protoType.getMessageType.getFields.get(1)
+ val message = value.asInstanceOf[DynamicMessage]
+ val seconds = message.getField(secondsField).asInstanceOf[Long]
+ val nanoSeconds = message.getField(nanoSecondsField).asInstanceOf[Int]
+ val micros = DateTimeUtils.millisToMicros(seconds * 1000)
+ updater.setLong(ordinal, micros + TimeUnit.NANOSECONDS.toMicros(nanoSeconds))
+
+ case (MESSAGE, st: StructType) =>
+ val writeRecord = getRecordWriter(
+ protoType.getMessageType,
+ st,
+ protoPath,
+ catalystPath,
+ applyFilters = _ => false)
+ (updater, ordinal, value) =>
+ val row = new SpecificInternalRow(st)
+ writeRecord(new RowUpdater(row), value.asInstanceOf[DynamicMessage])
+ updater.set(ordinal, row)
+
+ case (MESSAGE, ArrayType(st: StructType, containsNull)) =>
+ newArrayWriter(protoType, protoPath, catalystPath, st, containsNull)
+
+ case (ENUM, StringType) =>
+ (updater, ordinal, value) => updater.set(ordinal, UTF8String.fromString(value.toString))
+
+ case _ => throw new IncompatibleSchemaException(incompatibleMsg)
+ }
+ }
+
+ private def getRecordWriter(
+ protoType: Descriptor,
+ catalystType: StructType,
+ protoPath: Seq[String],
+ catalystPath: Seq[String],
+ applyFilters: Int => Boolean): (CatalystDataUpdater, DynamicMessage) => Boolean = {
+
+ val protoSchemaHelper =
+ new ProtobufUtils.ProtoSchemaHelper(protoType, catalystType, protoPath, catalystPath)
+
+ // TODO revisit validation of protobuf-catalyst fields.
+ // protoSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = true)
+
+ var i = 0
+ val (validFieldIndexes, fieldWriters) = protoSchemaHelper.matchedFields
+ .map { case ProtoMatchedField(catalystField, ordinal, protoField) =>
+ val baseWriter = newWriter(
+ protoField,
+ catalystField.dataType,
+ protoPath :+ protoField.getName,
+ catalystPath :+ catalystField.name)
+ val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => {
+ if (value == null) {
+ fieldUpdater.setNullAt(ordinal)
+ } else {
+ baseWriter(fieldUpdater, ordinal, value)
+ }
+ }
+ i += 1
+ (protoField, fieldWriter)
+ }
+ .toArray
+ .unzip
+
+ (fieldUpdater, record) => {
+ var i = 0
+ var skipRow = false
+ while (i < validFieldIndexes.length && !skipRow) {
+ val field = validFieldIndexes(i)
+ val value = if (field.isRepeated || field.hasDefaultValue || record.hasField(field)) {
+ record.getField(field)
+ } else null
+ fieldWriters(i)(fieldUpdater, value)
+ skipRow = applyFilters(i)
+ i += 1
+ }
+ skipRow
+ }
+ }
+
+ // TODO: All of the code below this line is the same between Protobuf and Avro and could be shared.
+ private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match {
+ case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length))
+ case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length))
+ case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length))
+ case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
+ case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
+ case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length))
+ case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length))
+ case _ => new GenericArrayData(new Array[Any](length))
+ }
+
+ /**
+ * A base interface for updating values inside Catalyst data structures such as `InternalRow`
+ * and `ArrayData`.
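+ *
+ * A minimal illustrative sketch (assuming `row` is a `SpecificInternalRow` whose field at
+ * ordinal 0 is an integer):
+ * {{{
+ *   val updater = new RowUpdater(row)
+ *   updater.setInt(0, 42) // write 42 into field 0
+ *   updater.setNullAt(0)  // or mark field 0 as null
+ * }}}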
+ */
+ sealed trait CatalystDataUpdater {
+ def set(ordinal: Int, value: Any): Unit
+ def setNullAt(ordinal: Int): Unit = set(ordinal, null)
+ def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value)
+ def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value)
+ def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value)
+ def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value)
+ def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value)
+ def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value)
+ def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value)
+ def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value)
+ }
+
+ final class RowUpdater(row: InternalRow) extends CatalystDataUpdater {
+ override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value)
+ override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal)
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value)
+ override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value)
+ override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value)
+ override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value)
+ override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value)
+ override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value)
+ override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value)
+ override def setDecimal(ordinal: Int, value: Decimal): Unit =
+ row.setDecimal(ordinal, value, value.precision)
+ }
+
+ final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater {
+ override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value)
+ override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal)
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value)
+ override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value)
+ override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value)
+ override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value)
+ override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value)
+ override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value)
+ override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value)
+ override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value)
+ }
+
+}
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
new file mode 100644
index 0000000000000..5d9af92c5c077
--- /dev/null
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.protobuf
+
+import scala.collection.JavaConverters._
+
+import com.google.protobuf.{Duration, DynamicMessage, Timestamp}
+import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
+import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
+import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
+import org.apache.spark.sql.protobuf.utils.ProtobufUtils
+import org.apache.spark.sql.protobuf.utils.ProtobufUtils.{toFieldStr, ProtoMatchedField}
+import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException
+import org.apache.spark.sql.types._
+
+/**
+ * A serializer to serialize data in catalyst format to data in Protobuf format.
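+ *
+ * A minimal illustrative sketch (assuming `schema` is the Catalyst `StructType` of the input
+ * `row`, and `descriptor` was obtained via `ProtobufUtils.buildDescriptor`):
+ * {{{
+ *   val serializer = new ProtobufSerializer(schema, descriptor, nullable = true)
+ *   val message = serializer.serialize(row).asInstanceOf[DynamicMessage]
+ *   val bytes = message.toByteArray
+ * }}}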
+ */
+private[sql] class ProtobufSerializer(
+ rootCatalystType: DataType,
+ rootDescriptor: Descriptor,
+ nullable: Boolean)
+ extends Logging {
+
+ def serialize(catalystData: Any): Any = {
+ converter.apply(catalystData)
+ }
+
+ private val converter: Any => Any = {
+ val baseConverter =
+ try {
+ rootCatalystType match {
+ case st: StructType =>
+ newStructConverter(st, rootDescriptor, Nil, Nil).asInstanceOf[Any => Any]
+ }
+ } catch {
+ case ise: IncompatibleSchemaException =>
+ throw new IncompatibleSchemaException(
+ s"Cannot convert SQL type ${rootCatalystType.sql} to Protobuf type " +
+ s"${rootDescriptor.getName}.",
+ ise)
+ }
+ if (nullable) { (data: Any) =>
+ if (data == null) {
+ null
+ } else {
+ baseConverter.apply(data)
+ }
+ } else {
+ baseConverter
+ }
+ }
+
+ private type Converter = (SpecializedGetters, Int) => Any
+
+ private def newConverter(
+ catalystType: DataType,
+ fieldDescriptor: FieldDescriptor,
+ catalystPath: Seq[String],
+ protoPath: Seq[String]): Converter = {
+ val errorPrefix = s"Cannot convert SQL ${toFieldStr(catalystPath)} " +
+ s"to Protobuf ${toFieldStr(protoPath)} because "
+ (catalystType, fieldDescriptor.getJavaType) match {
+ case (NullType, _) =>
+ (getter, ordinal) => null
+ case (BooleanType, BOOLEAN) =>
+ (getter, ordinal) => getter.getBoolean(ordinal)
+ case (ByteType, INT) =>
+ (getter, ordinal) => getter.getByte(ordinal).toInt
+ case (ShortType, INT) =>
+ (getter, ordinal) => getter.getShort(ordinal).toInt
+ case (IntegerType, INT) =>
+ (getter, ordinal) => {
+ getter.getInt(ordinal)
+ }
+ case (LongType, LONG) =>
+ (getter, ordinal) => getter.getLong(ordinal)
+ case (FloatType, FLOAT) =>
+ (getter, ordinal) => getter.getFloat(ordinal)
+ case (DoubleType, DOUBLE) =>
+ (getter, ordinal) => getter.getDouble(ordinal)
+ case (StringType, ENUM) =>
+ val enumSymbols: Set[String] =
+ fieldDescriptor.getEnumType.getValues.asScala.map(e => e.toString).toSet
+ (getter, ordinal) =>
+ val data = getter.getUTF8String(ordinal).toString
+ if (!enumSymbols.contains(data)) {
+ throw new IncompatibleSchemaException(
+ errorPrefix +
+ s""""$data" cannot be written since it's not defined in enum """ +
+ enumSymbols.mkString("\"", "\", \"", "\""))
+ }
+ fieldDescriptor.getEnumType.findValueByName(data)
+ case (StringType, STRING) =>
+ (getter, ordinal) => {
+ String.valueOf(getter.getUTF8String(ordinal))
+ }
+
+ case (BinaryType, BYTE_STRING) =>
+ (getter, ordinal) => getter.getBinary(ordinal)
+
+ case (DateType, INT) =>
+ (getter, ordinal) => getter.getInt(ordinal)
+
+ case (TimestampType, MESSAGE) =>
+ (getter, ordinal) =>
+ val millis = DateTimeUtils.microsToMillis(getter.getLong(ordinal))
+ Timestamp.newBuilder()
+ .setSeconds((millis / 1000))
+ .setNanos(((millis % 1000) * 1000000).toInt)
+ .build()
+
+ case (ArrayType(et, containsNull), _) =>
+ val elementConverter =
+ newConverter(et, fieldDescriptor, catalystPath :+ "element", protoPath :+ "element")
+ (getter, ordinal) => {
+ val arrayData = getter.getArray(ordinal)
+ val len = arrayData.numElements()
+ val result = new Array[Any](len)
+ var i = 0
+ while (i < len) {
+ if (containsNull && arrayData.isNullAt(i)) {
+ result(i) = null
+ } else {
+ result(i) = elementConverter(arrayData, i)
+ }
+ i += 1
+ }
+ // The Protobuf writer expects a Java Collection, so we wrap the result in a
+ // fixed-size list view backed by the array, without copying the data.
+ java.util.Arrays.asList(result: _*)
+ }
+
+ case (st: StructType, MESSAGE) =>
+ val structConverter =
+ newStructConverter(st, fieldDescriptor.getMessageType, catalystPath, protoPath)
+ val numFields = st.length
+ (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))
+
+ case (MapType(kt, vt, valueContainsNull), MESSAGE) =>
+ var keyField: FieldDescriptor = null
+ var valueField: FieldDescriptor = null
+ fieldDescriptor.getMessageType.getFields.asScala.foreach { field =>
+ field.getName match {
+ case "key" =>
+ keyField = field
+ case "value" =>
+ valueField = field
+ }
+ }
+
+ val keyConverter = newConverter(kt, keyField, catalystPath :+ "key", protoPath :+ "key")
+ val valueConverter =
+ newConverter(vt, valueField, catalystPath :+ "value", protoPath :+ "value")
+
+ (getter, ordinal) =>
+ val mapData = getter.getMap(ordinal)
+ val len = mapData.numElements()
+ val list = new java.util.ArrayList[DynamicMessage]()
+ val keyArray = mapData.keyArray()
+ val valueArray = mapData.valueArray()
+ var i = 0
+ while (i < len) {
+ val result = DynamicMessage.newBuilder(fieldDescriptor.getMessageType)
+ if (valueContainsNull && valueArray.isNullAt(i)) {
+ result.setField(keyField, keyConverter(keyArray, i))
+ result.setField(valueField, valueField.getDefaultValue)
+ } else {
+ result.setField(keyField, keyConverter(keyArray, i))
+ result.setField(valueField, valueConverter(valueArray, i))
+ }
+ list.add(result.build())
+ i += 1
+ }
+ list
+
+ case (DayTimeIntervalType(startField, endField), MESSAGE) =>
+ (getter, ordinal) =>
+ val dayTimeIntervalString =
+ IntervalUtils.toDayTimeIntervalString(getter.getLong(ordinal)
+ , ANSI_STYLE, startField, endField)
+ val calendarInterval = IntervalUtils.fromIntervalString(dayTimeIntervalString)
+
+ val millis = DateTimeUtils.microsToMillis(calendarInterval.microseconds)
+ val duration = Duration.newBuilder()
+ .setSeconds((millis / 1000))
+ .setNanos(((millis % 1000) * 1000000).toInt)
+
+ if (duration.getSeconds < 0 && duration.getNanos > 0) {
+ duration.setSeconds(duration.getSeconds + 1)
+ duration.setNanos(duration.getNanos - 1000000000)
+ } else if (duration.getSeconds > 0 && duration.getNanos < 0) {
+ duration.setSeconds(duration.getSeconds - 1)
+ duration.setNanos(duration.getNanos + 1000000000)
+ }
+ duration.build()
+
+ case _ =>
+ throw new IncompatibleSchemaException(
+ errorPrefix +
+ s"schema is incompatible (sqlType = ${catalystType.sql}, " +
+ s"protoType = ${fieldDescriptor.getJavaType})")
+ }
+ }
+
+ private def newStructConverter(
+ catalystStruct: StructType,
+ descriptor: Descriptor,
+ catalystPath: Seq[String],
+ protoPath: Seq[String]): InternalRow => DynamicMessage = {
+
+ val protoSchemaHelper =
+ new ProtobufUtils.ProtoSchemaHelper(descriptor, catalystStruct, protoPath, catalystPath)
+
+ protoSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = false)
+ protoSchemaHelper.validateNoExtraRequiredProtoFields()
+
+ val (protoIndices, fieldConverters: Array[Converter]) = protoSchemaHelper.matchedFields
+ .map { case ProtoMatchedField(catalystField, _, protoField) =>
+ val converter = newConverter(
+ catalystField.dataType,
+ protoField,
+ catalystPath :+ catalystField.name,
+ protoPath :+ protoField.getName)
+ (protoField, converter)
+ }
+ .toArray
+ .unzip
+
+ val numFields = catalystStruct.length
+ row: InternalRow =>
+ val result = DynamicMessage.newBuilder(descriptor)
+ var i = 0
+ while (i < numFields) {
+ if (row.isNullAt(i)) {
+ if (!protoIndices(i).isRepeated() &&
+ protoIndices(i).getJavaType() != FieldDescriptor.JavaType.MESSAGE &&
+ protoIndices(i).isRequired()) {
+ result.setField(protoIndices(i), protoIndices(i).getDefaultValue())
+ }
+ } else {
+ result.setField(protoIndices(i), fieldConverters(i).apply(row, i))
+ }
+ i += 1
+ }
+ result.build()
+ }
+}
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
new file mode 100644
index 0000000000000..283d1ca8c412c
--- /dev/null
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.protobuf
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.Column
+
+// scalastyle:off: object.name
+object functions {
+// scalastyle:on: object.name
+
+ /**
+ * Converts a binary column of Protobuf format into its corresponding Catalyst value. The
+ * specified Protobuf schema must match the actual schema of the read data, otherwise the
+ * behavior is undefined: it may fail or return an arbitrary result.
+ *
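+ * For example, a minimal illustrative sketch (the descriptor path here is a placeholder for a
+ * file generated with `protoc --descriptor_set_out`, and "SimpleMessage" is a message defined
+ * in it):
+ * {{{
+ *   import org.apache.spark.sql.protobuf.functions.from_protobuf
+ *
+ *   val options = new java.util.HashMap[String, String]()
+ *   options.put("mode", "PERMISSIVE")
+ *   val parsed = df.select(
+ *     from_protobuf(df.col("value"), "/path/to/events.desc", "SimpleMessage", options) as "event")
+ * }}}
+ *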
+ * @param data
+ * the binary column.
+ * @param descFilePath
+ * the path to the Protobuf descriptor file (generated with `protoc --descriptor_set_out`).
+ * @param messageName
+ * the Protobuf message name to look for in the descriptor file.
+ * @param options
+ * options to control how the Protobuf record is parsed, e.g. "mode" (PERMISSIVE or FAILFAST).
+ * @since 3.4.0
+ */
+ @Experimental
+ def from_protobuf(
+ data: Column,
+ descFilePath: String,
+ messageName: String,
+ options: java.util.Map[String, String]): Column = {
+ new Column(
+ ProtobufDataToCatalyst(data.expr, descFilePath, messageName, options.asScala.toMap))
+ }
+
+ /**
+ * Converts a binary column of Protobuf format into its corresponding Catalyst value. The
+ * specified Protobuf schema must match the actual schema of the read data, otherwise the
+ * behavior is undefined: it may fail or return an arbitrary result.
+ *
+ * @param data
+ * the binary column.
+ * @param descFilePath
+ * the path to the Protobuf descriptor file (generated with `protoc --descriptor_set_out`).
+ * @param messageName
+ * the Protobuf message name to look for in the descriptor file.
+ * @since 3.4.0
+ */
+ @Experimental
+ def from_protobuf(data: Column, descFilePath: String, messageName: String): Column = {
+ new Column(ProtobufDataToCatalyst(data.expr, descFilePath, messageName, Map.empty))
+ }
+
+ /**
+ * Converts a column into a binary column of Protobuf format.
+ *
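+ * For example, a minimal illustrative sketch (the descriptor path is a placeholder; the struct
+ * column "person" must match the fields of the `Person` message in that descriptor):
+ * {{{
+ *   import org.apache.spark.sql.protobuf.functions.to_protobuf
+ *
+ *   val binaryDf = df.select(
+ *     to_protobuf(df.col("person"), "/path/to/catalyst_types.desc", "Person") as "value")
+ * }}}
+ *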
+ * @param data
+ * the data column.
+ * @param descFilePath
+ * the path to the Protobuf descriptor file (generated with `protoc --descriptor_set_out`).
+ * @param messageName
+ * the Protobuf message name to look for in the descriptor file.
+ * @since 3.4.0
+ */
+ @Experimental
+ def to_protobuf(data: Column, descFilePath: String, messageName: String): Column = {
+ new Column(CatalystDataToProtobuf(data.expr, descFilePath, messageName))
+ }
+}
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/package.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/package.scala
new file mode 100644
index 0000000000000..82cdc6b9c5816
--- /dev/null
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/package.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+package object protobuf {
+ protected[protobuf] object ScalaReflectionLock
+}
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
new file mode 100644
index 0000000000000..1cece0d7966e5
--- /dev/null
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.protobuf.utils
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.FileSourceOptions
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode}
+
+/**
+ * Options for the Protobuf reader and writer, stored in a case-insensitive manner.
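+ *
+ * A minimal illustrative sketch of the only option currently read here, "mode":
+ * {{{
+ *   val opts = ProtobufOptions(Map("MODE" -> "PERMISSIVE"))
+ *   opts.parseMode // PermissiveMode: keys are matched case-insensitively
+ * }}}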
+ */
+private[sql] class ProtobufOptions(
+ @transient val parameters: CaseInsensitiveMap[String],
+ @transient val conf: Configuration)
+ extends FileSourceOptions(parameters)
+ with Logging {
+
+ def this(parameters: Map[String, String], conf: Configuration) = {
+ this(CaseInsensitiveMap(parameters), conf)
+ }
+
+ val parseMode: ParseMode =
+ parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode)
+}
+
+private[sql] object ProtobufOptions {
+ def apply(parameters: Map[String, String]): ProtobufOptions = {
+ val hadoopConf = SparkSession.getActiveSession
+ .map(_.sessionState.newHadoopConf())
+ .getOrElse(new Configuration())
+ new ProtobufOptions(CaseInsensitiveMap(parameters), hadoopConf)
+ }
+}
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
new file mode 100644
index 0000000000000..5ad043142a2d2
--- /dev/null
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.protobuf.utils
+
+import java.io.{BufferedInputStream, FileInputStream, IOException}
+import java.util.Locale
+
+import scala.collection.JavaConverters._
+
+import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException}
+import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException
+import org.apache.spark.sql.types._
+
+private[sql] object ProtobufUtils extends Logging {
+
+ /** Wrapper for a pair of matched fields, one Catalyst and one corresponding Protobuf field. */
+ private[sql] case class ProtoMatchedField(
+ catalystField: StructField,
+ catalystPosition: Int,
+ fieldDescriptor: FieldDescriptor)
+
+ /**
+ * Helper class to perform field lookup/matching on Protobuf schemas.
+ *
+ * This will match `descriptor` against `catalystSchema`, attempting to find a matching field in
+ * the Protobuf descriptor for each field in the Catalyst schema and vice-versa, respecting
+ * settings for case sensitivity. The match results can be accessed using the getter methods.
+ *
+ * @param descriptor
+ * The descriptor in which to search for fields. Must be of type Descriptor.
+ * @param catalystSchema
+ * The Catalyst schema to use for matching.
+ * @param protoPath
+ * The seq of parent field names leading to `descriptor`.
+ * @param catalystPath
+ * The seq of parent field names leading to `catalystSchema`.
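+ *
+ * A minimal illustrative sketch (assuming `descriptor` and `catalystSchema` describe the same
+ * record):
+ * {{{
+ *   val helper = new ProtobufUtils.ProtoSchemaHelper(descriptor, catalystSchema, Nil, Nil)
+ *   helper.validateNoExtraRequiredProtoFields()
+ *   helper.matchedFields.foreach { m =>
+ *     println(s"${m.catalystField.name} -> ${m.fieldDescriptor.getName}")
+ *   }
+ * }}}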
+ */
+ class ProtoSchemaHelper(
+ descriptor: Descriptor,
+ catalystSchema: StructType,
+ protoPath: Seq[String],
+ catalystPath: Seq[String]) {
+ if (descriptor.getName == null) {
+ throw new IncompatibleSchemaException(
+ s"Attempting to treat ${descriptor.getName} as a RECORD, " +
+ s"but it was: ${descriptor.getContainingType}")
+ }
+
+ private[this] val protoFieldArray = descriptor.getFields.asScala.toArray
+ private[this] val fieldMap = descriptor.getFields.asScala
+ .groupBy(_.getName.toLowerCase(Locale.ROOT))
+ .mapValues(_.toSeq) // toSeq needed for scala 2.13
+
+ /** The fields which have matching equivalents in both Protobuf and Catalyst schemas. */
+ val matchedFields: Seq[ProtoMatchedField] = catalystSchema.zipWithIndex.flatMap {
+ case (sqlField, sqlPos) =>
+ getFieldByName(sqlField.name).map(ProtoMatchedField(sqlField, sqlPos, _))
+ }
+
+ /**
+ * Validate that there are no Catalyst fields which don't have a matching Protobuf field,
+ * throwing [[IncompatibleSchemaException]] if such extra fields are found. If
+ * `ignoreNullable` is false, consider nullable Catalyst fields to be eligible to be an extra
+ * field; otherwise, ignore nullable Catalyst fields when checking for extras.
+ */
+ def validateNoExtraCatalystFields(ignoreNullable: Boolean): Unit =
+ catalystSchema.fields.foreach { sqlField =>
+ if (getFieldByName(sqlField.name).isEmpty &&
+ (!ignoreNullable || !sqlField.nullable)) {
+ throw new IncompatibleSchemaException(
+ s"Cannot find ${toFieldStr(catalystPath :+ sqlField.name)} in Protobuf schema")
+ }
+ }
+
+ /**
+ * Validate that there are no Protobuf fields which don't have a matching Catalyst field,
+ * throwing [[IncompatibleSchemaException]] if such extra fields are found. Only required
+ * (non-nullable) fields are checked; nullable fields are ignored.
+ */
+ def validateNoExtraRequiredProtoFields(): Unit = {
+ val extraFields = protoFieldArray.toSet -- matchedFields.map(_.fieldDescriptor)
+ extraFields.filterNot(isNullable).foreach { extraField =>
+ throw new IncompatibleSchemaException(
+ s"Found ${toFieldStr(protoPath :+ extraField.getName())} in Protobuf schema " +
+ "but there is no match in the SQL schema")
+ }
+ }
+
+ /**
+ * Extract a single field from the contained Protobuf schema which has the desired field name,
+ * performing the matching with proper case sensitivity according to SQLConf.resolver.
+ *
+ * @param name
+ * The name of the field to search for.
+ * @return
+ * `Some(match)` if a matching Protobuf field is found, otherwise `None`.
+ */
+ private[protobuf] def getFieldByName(name: String): Option[FieldDescriptor] = {
+
+ // get candidates, ignoring case of field name
+ val candidates = fieldMap.getOrElse(name.toLowerCase(Locale.ROOT), Seq.empty)
+
+ // search candidates, taking into account case sensitivity settings
+ candidates.filter(f => SQLConf.get.resolver(f.getName(), name)) match {
+ case Seq(protoField) => Some(protoField)
+ case Seq() => None
+ case matches =>
+ throw new IncompatibleSchemaException(
+ s"Searching for '$name' in " +
+ s"Protobuf schema at ${toFieldStr(protoPath)} gave ${matches.size} matches. " +
+ s"Candidates: " + matches.map(_.getName()).mkString("[", ", ", "]"))
+ }
+ }
+ }
+
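+  /**
+   * Builds the Protobuf [[Descriptor]] for `messageName` by parsing the descriptor set file at
+   * `descFilePath` (typically generated with `protoc --descriptor_set_out`). Throws a
+   * RuntimeException if the message cannot be found in that file.
+   */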
+  def buildDescriptor(descFilePath: String, messageName: String): Descriptor = {
+    val fileDescriptor: Descriptors.FileDescriptor = parseFileDescriptor(descFilePath)
+
+    fileDescriptor.getMessageTypes.asScala
+      .find(_.getName == messageName)
+      .getOrElse {
+        throw new RuntimeException("Unable to locate Message '" + messageName + "' in Descriptor")
+      }
+  }
+
+ def parseFileDescriptor(descFilePath: String): Descriptors.FileDescriptor = {
+ var fileDescriptorSet: DescriptorProtos.FileDescriptorSet = null
+ try {
+ val dscFile = new BufferedInputStream(new FileInputStream(descFilePath))
+ fileDescriptorSet = DescriptorProtos.FileDescriptorSet.parseFrom(dscFile)
+ } catch {
+ case ex: InvalidProtocolBufferException =>
+ // TODO move all the exceptions to core/src/main/resources/error/error-classes.json
+ throw new RuntimeException("Error parsing descriptor byte[] into Descriptor object", ex)
+ case ex: IOException =>
+ throw new RuntimeException(
+ "Error reading Protobuf descriptor file at path: " +
+ descFilePath,
+ ex)
+ }
+
+ val descriptorProto: DescriptorProtos.FileDescriptorProto = fileDescriptorSet.getFile(0)
+ try {
+ val fileDescriptor: Descriptors.FileDescriptor = Descriptors.FileDescriptor.buildFrom(
+ descriptorProto,
+ new Array[Descriptors.FileDescriptor](0))
+ if (fileDescriptor.getMessageTypes().isEmpty()) {
+ throw new RuntimeException("No MessageTypes returned, " + fileDescriptor.getName());
+ }
+ fileDescriptor
+ } catch {
+ case e: Descriptors.DescriptorValidationException =>
+ throw new RuntimeException("Error constructing FileDescriptor", e)
+ }
+ }
+
+ /**
+ * Convert a sequence of hierarchical field names (like `Seq(foo, bar)`) into a human-readable
+ * string representing the field, like "field 'foo.bar'". If `names` is empty, the string
+ * "top-level record" is returned.
+ */
+ private[protobuf] def toFieldStr(names: Seq[String]): String = names match {
+ case Seq() => "top-level record"
+ case n => s"field '${n.mkString(".")}'"
+ }
+
+ /** Return true if `fieldDescriptor` is optional. */
+ private[protobuf] def isNullable(fieldDescriptor: FieldDescriptor): Boolean =
+ !fieldDescriptor.isOptional
+
+}
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
new file mode 100644
index 0000000000000..e385b816abe70
--- /dev/null
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.protobuf.utils
+
+import scala.collection.JavaConverters._
+
+import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.protobuf.ScalaReflectionLock
+import org.apache.spark.sql.types._
+
+@DeveloperApi
+object SchemaConverters {
+
+ /**
+ * Internal wrapper for SQL data type and nullability.
+ *
+ * @since 3.4.0
+ */
+ case class SchemaType(dataType: DataType, nullable: Boolean)
+
+ /**
+ * Converts a Protobuf schema to a corresponding Spark SQL schema.
+ *
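+ * For example, a minimal illustrative sketch (the descriptor path and message name are
+ * placeholders):
+ * {{{
+ *   val descriptor = ProtobufUtils.buildDescriptor("/path/to/events.desc", "SimpleMessage")
+ *   val schema = SchemaConverters.toSqlType(descriptor).dataType.asInstanceOf[StructType]
+ * }}}
+ *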
+ * @since 3.4.0
+ */
+ def toSqlType(descriptor: Descriptor): SchemaType = {
+ toSqlTypeHelper(descriptor)
+ }
+
+ def toSqlTypeHelper(descriptor: Descriptor): SchemaType = ScalaReflectionLock.synchronized {
+ SchemaType(
+ StructType(descriptor.getFields.asScala.flatMap(structFieldFor(_, Set.empty)).toSeq),
+ nullable = true)
+ }
+
+ def structFieldFor(
+ fd: FieldDescriptor,
+ existingRecordNames: Set[String]): Option[StructField] = {
+ import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
+ val dataType = fd.getJavaType match {
+ case INT => Some(IntegerType)
+ case LONG => Some(LongType)
+ case FLOAT => Some(FloatType)
+ case DOUBLE => Some(DoubleType)
+ case BOOLEAN => Some(BooleanType)
+ case STRING => Some(StringType)
+ case BYTE_STRING => Some(BinaryType)
+ case ENUM => Some(StringType)
+ case MESSAGE if fd.getMessageType.getName == "Duration" =>
+ Some(DayTimeIntervalType.defaultConcreteType)
+ case MESSAGE if fd.getMessageType.getName == "Timestamp" =>
+ Some(TimestampType)
+ case MESSAGE if fd.isRepeated && fd.getMessageType.getOptions.hasMapEntry =>
+ var keyType: DataType = NullType
+ var valueType: DataType = NullType
+ fd.getMessageType.getFields.forEach { field =>
+ field.getName match {
+ case "key" =>
+ keyType = structFieldFor(field, existingRecordNames).get.dataType
+ case "value" =>
+ valueType = structFieldFor(field, existingRecordNames).get.dataType
+ }
+ }
+ return Option(
+ StructField(
+ fd.getName,
+ MapType(keyType, valueType, valueContainsNull = false).defaultConcreteType,
+ nullable = false))
+ case MESSAGE =>
+ if (existingRecordNames.contains(fd.getFullName)) {
+ throw new IncompatibleSchemaException(s"""
+ |Found recursive reference in Protobuf schema, which can not be processed by Spark:
+ |${fd.toString()}""".stripMargin)
+ }
+ val newRecordNames = existingRecordNames + fd.getFullName
+
+ Option(
+ fd.getMessageType.getFields.asScala
+ .flatMap(structFieldFor(_, newRecordNames.toSet))
+ .toSeq)
+ .filter(_.nonEmpty)
+ .map(StructType.apply)
+ case _ =>
+ throw new IncompatibleSchemaException(
+ s"Cannot convert Protobuf type" +
+ s" ${fd.getJavaType}")
+ }
+ dataType.map(dt =>
+ StructField(
+ fd.getName,
+ if (fd.isRepeated) ArrayType(dt, containsNull = false) else dt,
+ nullable = !fd.isRequired && !fd.isRepeated))
+ }
+
+ private[protobuf] class IncompatibleSchemaException(msg: String, ex: Throwable = null)
+ extends Exception(msg, ex)
+}
diff --git a/connector/protobuf/src/test/resources/protobuf/catalyst_types.desc b/connector/protobuf/src/test/resources/protobuf/catalyst_types.desc
new file mode 100644
index 0000000000000..59255b488a03d
Binary files /dev/null and b/connector/protobuf/src/test/resources/protobuf/catalyst_types.desc differ
diff --git a/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto b/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto
new file mode 100644
index 0000000000000..54e6bc18df153
--- /dev/null
+++ b/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/catalyst_types.proto
+// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/catalyst_types.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/catalyst_types.proto
+
+syntax = "proto3";
+
+package org.apache.spark.sql.protobuf;
+option java_outer_classname = "CatalystTypes";
+
+message BooleanMsg {
+ bool bool_type = 1;
+}
+message IntegerMsg {
+ int32 int32_type = 1;
+}
+message DoubleMsg {
+ double double_type = 1;
+}
+message FloatMsg {
+ float float_type = 1;
+}
+message BytesMsg {
+ bytes bytes_type = 1;
+}
+message StringMsg {
+ string string_type = 1;
+}
+
+message Person {
+ string name = 1;
+ int32 age = 2;
+}
+
+message Bad {
+ bytes col_0 = 1;
+ double col_1 = 2;
+ string col_2 = 3;
+ float col_3 = 4;
+ int64 col_4 = 5;
+}
+
+message Actual {
+ string col_0 = 1;
+ int32 col_1 = 2;
+ float col_2 = 3;
+ bool col_3 = 4;
+ double col_4 = 5;
+}
+
+message oldConsumer {
+ string key = 1;
+}
+
+message newProducer {
+ string key = 1;
+ int32 value = 2;
+}
+
+message newConsumer {
+ string key = 1;
+ int32 value = 2;
+ Actual actual = 3;
+}
+
+message oldProducer {
+ string key = 1;
+}
\ No newline at end of file
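
The oldProducer/newProducer/oldConsumer/newConsumer messages above model schema evolution between writers and readers, and the functions suite later in this patch checks both directions. A minimal sketch of the forward case, assuming a local SparkSession, the compiled catalyst_types.desc, and placement in the org.apache.spark.sql.protobuf package so that the functions object resolves; the object name is hypothetical:

package org.apache.spark.sql.protobuf

import com.google.protobuf.DynamicMessage

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.protobuf.utils.ProtobufUtils

object SchemaEvolutionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    val desc = "connector/protobuf/src/test/resources/protobuf/catalyst_types.desc"

    // Serialize a record with the old producer schema, which only has `key`.
    val oldProducer = ProtobufUtils.buildDescriptor(desc, "oldProducer")
    val bytes = DynamicMessage.newBuilder(oldProducer)
      .setField(oldProducer.findFieldByName("key"), "key1")
      .build()
      .toByteArray

    // Read it back with the newer consumer schema; `value` and `actual` are absent
    // from the payload, so they come back as nulls rather than failing the read.
    val evolved = Seq(bytes).toDF("data")
      .select(functions.from_protobuf($"data", desc, "newConsumer").as("evolved"))
    evolved.select("evolved.*").show()

    spark.stop()
  }
}
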
diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc
new file mode 100644
index 0000000000000..6e3a396727729
Binary files /dev/null and b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc differ
diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
new file mode 100644
index 0000000000000..f38c041b799ec
--- /dev/null
+++ b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// To compile and create test class:
+// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/functions_suite.proto
+// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/functions_suite.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/functions_suite.proto
+
+syntax = "proto3";
+
+package org.apache.spark.sql.protobuf;
+
+option java_outer_classname = "SimpleMessageProtos";
+
+message SimpleMessageJavaTypes {
+ int64 id = 1;
+ string string_value = 2;
+ int32 int32_value = 3;
+ int64 int64_value = 4;
+ double double_value = 5;
+ float float_value = 6;
+ bool bool_value = 7;
+ bytes bytes_value = 8;
+}
+
+message SimpleMessage {
+ int64 id = 1;
+ string string_value = 2;
+ int32 int32_value = 3;
+ uint32 uint32_value = 4;
+ sint32 sint32_value = 5;
+ fixed32 fixed32_value = 6;
+ sfixed32 sfixed32_value = 7;
+ int64 int64_value = 8;
+ uint64 uint64_value = 9;
+ sint64 sint64_value = 10;
+ fixed64 fixed64_value = 11;
+ sfixed64 sfixed64_value = 12;
+ double double_value = 13;
+ float float_value = 14;
+ bool bool_value = 15;
+ bytes bytes_value = 16;
+}
+
+message SimpleMessageRepeated {
+ string key = 1;
+ string value = 2;
+ enum NestedEnum {
+ ESTED_NOTHING = 0;
+ NESTED_FIRST = 1;
+ NESTED_SECOND = 2;
+ }
+ repeated string rstring_value = 3;
+ repeated int32 rint32_value = 4;
+ repeated bool rbool_value = 5;
+ repeated int64 rint64_value = 6;
+ repeated float rfloat_value = 7;
+ repeated double rdouble_value = 8;
+ repeated bytes rbytes_value = 9;
+ repeated NestedEnum rnested_enum = 10;
+}
+
+message BasicMessage {
+ int64 id = 1;
+ string string_value = 2;
+ int32 int32_value = 3;
+ int64 int64_value = 4;
+ double double_value = 5;
+ float float_value = 6;
+ bool bool_value = 7;
+ bytes bytes_value = 8;
+}
+
+message RepeatedMessage {
+ repeated BasicMessage basic_message = 1;
+}
+
+message SimpleMessageMap {
+ string key = 1;
+ string value = 2;
+ map<string, string> string_mapdata = 3;
+ map<int32, int32> int32_mapdata = 4;
+ map<uint32, uint32> uint32_mapdata = 5;
+ map<sint32, sint32> sint32_mapdata = 6;
+ map<fixed32, fixed32> float32_mapdata = 7;
+ map<sfixed32, sfixed32> sfixed32_mapdata = 8;
+ map<int64, int64> int64_mapdata = 9;
+ map<uint64, uint64> uint64_mapdata = 10;
+ map<sint64, sint64> sint64_mapdata = 11;
+ map<fixed64, fixed64> fixed64_mapdata = 12;
+ map<sfixed64, sfixed64> sfixed64_mapdata = 13;
+ map<string, double> double_mapdata = 14;
+ map<string, float> float_mapdata = 15;
+ map<bool, bool> bool_mapdata = 16;
+ map<string, bytes> bytes_mapdata = 17;
+}
+
+message BasicEnumMessage {
+ enum BasicEnum {
+ NOTHING = 0;
+ FIRST = 1;
+ SECOND = 2;
+ }
+}
+
+message SimpleMessageEnum {
+ string key = 1;
+ string value = 2;
+ enum NestedEnum {
+ ESTED_NOTHING = 0;
+ NESTED_FIRST = 1;
+ NESTED_SECOND = 2;
+ }
+ BasicEnumMessage.BasicEnum basic_enum = 3;
+ NestedEnum nested_enum = 4;
+}
+
+
+message OtherExample {
+ string other = 1;
+}
+
+message IncludedExample {
+ string included = 1;
+ OtherExample other = 2;
+}
+
+message MultipleExample {
+ IncludedExample included_example = 1;
+}
+
+message recursiveA {
+ string keyA = 1;
+ recursiveB messageB = 2;
+}
+
+message recursiveB {
+ string keyB = 1;
+ recursiveA messageA = 2;
+}
+
+message recursiveC {
+ string keyC = 1;
+ recursiveD messageD = 2;
+}
+
+message recursiveD {
+ string keyD = 1;
+ repeated recursiveC messageC = 2;
+}
+
+message requiredMsg {
+ string key = 1;
+ int32 col_1 = 2;
+ string col_2 = 3;
+ int32 col_3 = 4;
+}
+
+// https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/timestamp.proto
+message Timestamp {
+ int64 seconds = 1;
+ int32 nanos = 2;
+}
+
+message timeStampMsg {
+ string key = 1;
+ Timestamp stmp = 2;
+}
+// https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/duration.proto
+message Duration {
+ int64 seconds = 1;
+ int32 nanos = 2;
+}
+
+message durationMsg {
+ string key = 1;
+ Duration duration = 2;
+}
\ No newline at end of file
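
The local Timestamp and Duration messages above mirror the well-known protobuf types, and the functions suite at the end of this patch verifies that they round-trip through Spark's TimestampType and DayTimeIntervalType. A minimal sketch of what a reader sees for the timestamp case, assuming a local SparkSession, the compiled functions_suite.desc, and placement in the org.apache.spark.sql.protobuf package; the object name and the example instant are illustrative:

package org.apache.spark.sql.protobuf

import com.google.protobuf.DynamicMessage

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.protobuf.utils.ProtobufUtils

object TimestampMappingSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    val desc = "connector/protobuf/src/test/resources/protobuf/functions_suite.desc"
    val msgDesc = ProtobufUtils.buildDescriptor(desc, "timeStampMsg")
    val tsDesc = ProtobufUtils.buildDescriptor(desc, "Timestamp")

    // An example instant encoded the way google.protobuf.Timestamp does: seconds
    // since the epoch plus a nanosecond component.
    val ts = DynamicMessage.newBuilder(tsDesc)
      .setField(tsDesc.findFieldByName("seconds"), 1462788763L)
      .setField(tsDesc.findFieldByName("nanos"), 999000000)
      .build()
    val msg = DynamicMessage.newBuilder(msgDesc)
      .setField(msgDesc.findFieldByName("key"), "key1")
      .setField(msgDesc.findFieldByName("stmp"), ts)
      .build()

    val df = Seq(msg.toByteArray).toDF("data")
      .select(functions.from_protobuf($"data", desc, "timeStampMsg").as("t"))
    // `stmp` surfaces as a Spark timestamp column rather than a nested struct.
    df.select("t.*").printSchema()

    spark.stop()
  }
}
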
diff --git a/connector/protobuf/src/test/resources/protobuf/serde_suite.desc b/connector/protobuf/src/test/resources/protobuf/serde_suite.desc
new file mode 100644
index 0000000000000..3d1847eecc5c3
Binary files /dev/null and b/connector/protobuf/src/test/resources/protobuf/serde_suite.desc differ
diff --git a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto
new file mode 100644
index 0000000000000..1e3065259aa02
--- /dev/null
+++ b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// To compile and create test class:
+// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/serde_suite.proto
+// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/serde_suite.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/serde_suite.proto
+
+syntax = "proto3";
+
+package org.apache.spark.sql.protobuf;
+option java_outer_classname = "SimpleMessageProtos";
+
+/* Clean message that matches the SQL schema */
+message BasicMessage {
+ Foo foo = 1;
+}
+
+message Foo {
+ int32 bar = 1;
+}
+
+/* Field type mismatch in the root message */
+message MissMatchTypeInRoot {
+ int64 foo = 1;
+}
+
+/* Field bar missing from the Protobuf schema but available in SQL */
+message FieldMissingInProto {
+ MissingField foo = 1;
+}
+
+message MissingField {
+ int64 barFoo = 1;
+}
+
+/* Type mismatch on the deeply nested field bar */
+message MissMatchTypeInDeepNested {
+ TypeMissNested top = 1;
+}
+
+message TypeMissNested {
+ TypeMiss foo = 1;
+}
+
+message TypeMiss {
+ int64 bar = 1;
+}
+
+/* Field boo missing from the SQL root but available in the Protobuf root */
+message FieldMissingInSQLRoot {
+ Foo foo = 1;
+ int32 boo = 2;
+}
+
+/* Field baz missing from the nested SQL schema but available in the nested Protobuf message */
+message FieldMissingInSQLNested {
+ Baz foo = 1;
+}
+
+message Baz {
+ int32 bar = 1;
+ int32 baz = 2;
+}
\ No newline at end of file
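
These deliberately mismatched messages drive ProtobufSerdeSuite at the end of this patch, which exercises ProtobufSerializer and ProtobufDeserializer directly rather than through the SQL functions. A minimal sketch of that lower-level round trip on the clean BasicMessage/Foo pair, assuming the compiled serde_suite.desc and placement next to the tests in the org.apache.spark.sql.protobuf package (the serde classes are package-private); the object name is hypothetical:

package org.apache.spark.sql.protobuf

import com.google.protobuf.DynamicMessage

import org.apache.spark.sql.catalyst.NoopFilters
import org.apache.spark.sql.protobuf.utils.ProtobufUtils
import org.apache.spark.sql.types.{IntegerType, StructType}

object SerdeRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val desc = "connector/protobuf/src/test/resources/protobuf/serde_suite.desc"
    val basicMessage = ProtobufUtils.buildDescriptor(desc, "BasicMessage")
    val fooDesc = basicMessage.getFile.findMessageTypeByName("Foo")

    // SQL shape mirroring BasicMessage { Foo foo = 1; } / Foo { int32 bar = 1; }.
    val catalystStruct = new StructType()
      .add("foo", new StructType().add("bar", IntegerType))

    val message = DynamicMessage.newBuilder(basicMessage)
      .setField(
        basicMessage.findFieldByName("foo"),
        DynamicMessage.newBuilder(fooDesc)
          .setField(fooDesc.findFieldByName("bar"), 10902)
          .build())
      .build()

    // Deserialize to a Catalyst row and serialize back; with matching schemas the
    // round trip reproduces the original message, which is what the suite asserts.
    val deserializer = new ProtobufDeserializer(basicMessage, catalystStruct, new NoopFilters)
    val serializer = new ProtobufSerializer(catalystStruct, basicMessage, false)
    val row = deserializer.deserialize(message).get
    println(serializer.serialize(row) == message)
  }
}
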
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
new file mode 100644
index 0000000000000..b730ebb4fea80
--- /dev/null
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.protobuf
+
+import com.google.protobuf.{ByteString, DynamicMessage, Message}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{RandomDataGenerator, Row}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, NoopFilters, OrderedFilters, StructFilters}
+import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
+import org.apache.spark.sql.protobuf.utils.{ProtobufUtils, SchemaConverters}
+import org.apache.spark.sql.sources.{EqualTo, Not}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+class ProtobufCatalystDataConversionSuite
+ extends SparkFunSuite
+ with SharedSparkSession
+ with ExpressionEvalHelper {
+
+ private def checkResult(
+ data: Literal,
+ descFilePath: String,
+ messageName: String,
+ expected: Any): Unit = {
+ checkEvaluation(
+ ProtobufDataToCatalyst(
+ CatalystDataToProtobuf(data, descFilePath, messageName),
+ descFilePath,
+ messageName,
+ Map.empty),
+ prepareExpectedResult(expected))
+ }
+
+ protected def checkUnsupportedRead(
+ data: Literal,
+ descFilePath: String,
+ actualSchema: String,
+ badSchema: String): Unit = {
+
+ val binary = CatalystDataToProtobuf(data, descFilePath, actualSchema)
+
+ intercept[Exception] {
+ ProtobufDataToCatalyst(binary, descFilePath, badSchema, Map("mode" -> "FAILFAST")).eval()
+ }
+
+ val expected = {
+ val expectedSchema = ProtobufUtils.buildDescriptor(descFilePath, badSchema)
+ SchemaConverters.toSqlType(expectedSchema).dataType match {
+ case st: StructType =>
+ Row.fromSeq((0 until st.length).map { _ =>
+ null
+ })
+ case _ => null
+ }
+ }
+
+ checkEvaluation(
+ ProtobufDataToCatalyst(binary, descFilePath, badSchema, Map("mode" -> "PERMISSIVE")),
+ expected)
+ }
+
+ protected def prepareExpectedResult(expected: Any): Any = expected match {
+ // Spark byte and short both map to Protobuf int
+ case b: Byte => b.toInt
+ case s: Short => s.toInt
+ case row: GenericInternalRow => InternalRow.fromSeq(row.values.map(prepareExpectedResult))
+ case array: GenericArrayData => new GenericArrayData(array.array.map(prepareExpectedResult))
+ case map: MapData =>
+ val keys = new GenericArrayData(
+ map.keyArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult))
+ val values = new GenericArrayData(
+ map.valueArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult))
+ new ArrayBasedMapData(keys, values)
+ case other => other
+ }
+
+ private val testingTypes = Seq(
+ StructType(StructField("int32_type", IntegerType, nullable = true) :: Nil),
+ StructType(StructField("double_type", DoubleType, nullable = true) :: Nil),
+ StructType(StructField("float_type", FloatType, nullable = true) :: Nil),
+ StructType(StructField("bytes_type", BinaryType, nullable = true) :: Nil),
+ StructType(StructField("string_type", StringType, nullable = true) :: Nil))
+
+ private val catalystTypesToProtoMessages: Map[DataType, String] = Map(
+ IntegerType -> "IntegerMsg",
+ DoubleType -> "DoubleMsg",
+ FloatType -> "FloatMsg",
+ BinaryType -> "BytesMsg",
+ StringType -> "StringMsg")
+
+ testingTypes.foreach { dt =>
+ val seed = 1 + scala.util.Random.nextInt((1024 - 1) + 1)
+ val filePath = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
+ test(s"single $dt with seed $seed") {
+ val rand = new scala.util.Random(seed)
+ val data = RandomDataGenerator.forType(dt, rand = rand).get.apply()
+ val converter = CatalystTypeConverters.createToCatalystConverter(dt)
+ val input = Literal.create(converter(data), dt)
+
+ checkResult(
+ input,
+ filePath,
+ catalystTypesToProtoMessages(dt.fields(0).dataType),
+ input.eval())
+ }
+ }
+
+ private def checkDeserialization(
+ descFilePath: String,
+ messageName: String,
+ data: Message,
+ expected: Option[Any],
+ filters: StructFilters = new NoopFilters): Unit = {
+
+ val descriptor = ProtobufUtils.buildDescriptor(descFilePath, messageName)
+ val dataType = SchemaConverters.toSqlType(descriptor).dataType
+
+ val deserializer = new ProtobufDeserializer(descriptor, dataType, filters)
+
+ val dynMsg = DynamicMessage.parseFrom(descriptor, data.toByteArray)
+ val deserialized = deserializer.deserialize(dynMsg)
+ expected match {
+ case None => assert(deserialized.isEmpty)
+ case Some(d) =>
+ assert(checkResult(d, deserialized.get, dataType, exprNullable = false))
+ }
+ }
+
+ test("Handle unsupported input of message type") {
+ val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
+ val actualSchema = StructType(
+ Seq(
+ StructField("col_0", StringType, nullable = false),
+ StructField("col_1", IntegerType, nullable = false),
+ StructField("col_2", FloatType, nullable = false),
+ StructField("col_3", BooleanType, nullable = false),
+ StructField("col_4", DoubleType, nullable = false)))
+
+ val seed = scala.util.Random.nextLong()
+ withClue(s"create random record with seed $seed") {
+ val data = RandomDataGenerator.randomRow(new scala.util.Random(seed), actualSchema)
+ val converter = CatalystTypeConverters.createToCatalystConverter(actualSchema)
+ val input = Literal.create(converter(data), actualSchema)
+ checkUnsupportedRead(input, testFileDesc, "Actual", "Bad")
+ }
+ }
+
+ test("filter push-down to Protobuf deserializer") {
+
+ val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
+ val sqlSchema = new StructType()
+ .add("name", "string")
+ .add("age", "int")
+
+ val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "Person")
+ val dynamicMessage = DynamicMessage
+ .newBuilder(descriptor)
+ .setField(descriptor.findFieldByName("name"), "Maxim")
+ .setField(descriptor.findFieldByName("age"), 39)
+ .build()
+
+ val expectedRow = Some(InternalRow(UTF8String.fromString("Maxim"), 39))
+ checkDeserialization(testFileDesc, "Person", dynamicMessage, expectedRow)
+ checkDeserialization(
+ testFileDesc,
+ "Person",
+ dynamicMessage,
+ expectedRow,
+ new OrderedFilters(Seq(EqualTo("age", 39)), sqlSchema))
+
+ checkDeserialization(
+ testFileDesc,
+ "Person",
+ dynamicMessage,
+ None,
+ new OrderedFilters(Seq(Not(EqualTo("name", "Maxim"))), sqlSchema))
+ }
+
+ test("ProtobufDeserializer with binary type") {
+
+ val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
+ val bb = java.nio.ByteBuffer.wrap(Array[Byte](97, 48, 53))
+
+ val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "BytesMsg")
+
+ val dynamicMessage = DynamicMessage
+ .newBuilder(descriptor)
+ .setField(descriptor.findFieldByName("bytes_type"), ByteString.copyFrom(bb))
+ .build()
+
+ val expected = InternalRow(Array[Byte](97, 48, 53))
+ checkDeserialization(testFileDesc, "BytesMsg", dynamicMessage, Some(expected))
+ }
+}
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
new file mode 100644
index 0000000000000..4e9bc1c1c287a
--- /dev/null
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -0,0 +1,615 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.protobuf
+
+import java.sql.Timestamp
+import java.time.Duration
+
+import scala.collection.JavaConverters._
+
+import com.google.protobuf.{ByteString, DynamicMessage}
+
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.functions.{lit, struct}
+import org.apache.spark.sql.protobuf.utils.ProtobufUtils
+import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{DayTimeIntervalType, IntegerType, StringType, StructField, StructType, TimestampType}
+
+class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Serializable {
+
+ import testImplicits._
+
+ val testFileDesc = testFile("protobuf/functions_suite.desc").replace("file:/", "/")
+
+ test("roundtrip in to_protobuf and from_protobuf - struct") {
+ val df = spark
+ .range(1, 10)
+ .select(struct(
+ $"id",
+ $"id".cast("string").as("string_value"),
+ $"id".cast("int").as("int32_value"),
+ $"id".cast("int").as("uint32_value"),
+ $"id".cast("int").as("sint32_value"),
+ $"id".cast("int").as("fixed32_value"),
+ $"id".cast("int").as("sfixed32_value"),
+ $"id".cast("long").as("int64_value"),
+ $"id".cast("long").as("uint64_value"),
+ $"id".cast("long").as("sint64_value"),
+ $"id".cast("long").as("fixed64_value"),
+ $"id".cast("long").as("sfixed64_value"),
+ $"id".cast("double").as("double_value"),
+ lit(1202.00).cast(org.apache.spark.sql.types.FloatType).as("float_value"),
+ lit(true).as("bool_value"),
+ lit("0".getBytes).as("bytes_value")).as("SimpleMessage"))
+ val protoStructDF = df.select(
+ functions.to_protobuf($"SimpleMessage", testFileDesc, "SimpleMessage").as("proto"))
+ val actualDf = protoStructDF.select(
+ functions.from_protobuf($"proto", testFileDesc, "SimpleMessage").as("proto.*"))
+ checkAnswer(actualDf, df)
+ }
+
+ test("roundtrip in from_protobuf and to_protobuf - Repeated") {
+ val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageRepeated")
+
+ val dynamicMessage = DynamicMessage
+ .newBuilder(descriptor)
+ .setField(descriptor.findFieldByName("key"), "key")
+ .setField(descriptor.findFieldByName("value"), "value")
+ .addRepeatedField(descriptor.findFieldByName("rbool_value"), false)
+ .addRepeatedField(descriptor.findFieldByName("rbool_value"), true)
+ .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092092.654d)
+ .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092093.654d)
+ .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10903.0f)
+ .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10902.0f)
+ .addRepeatedField(
+ descriptor.findFieldByName("rnested_enum"),
+ descriptor.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING"))
+ .addRepeatedField(
+ descriptor.findFieldByName("rnested_enum"),
+ descriptor.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST"))
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(
+ functions.from_protobuf($"value", testFileDesc, "SimpleMessageRepeated").as("value_from"))
+ val toProtoDF = fromProtoDF.select(
+ functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageRepeated").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ functions
+ .from_protobuf($"value_to", testFileDesc, "SimpleMessageRepeated")
+ .as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("roundtrip in from_protobuf and to_protobuf - Repeated Message Once") {
+ val repeatedMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "RepeatedMessage")
+ val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage")
+
+ val basicMessage = DynamicMessage
+ .newBuilder(basicMessageDesc)
+ .setField(basicMessageDesc.findFieldByName("id"), 1111L)
+ .setField(basicMessageDesc.findFieldByName("string_value"), "value")
+ .setField(basicMessageDesc.findFieldByName("int32_value"), 12345)
+ .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L)
+ .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0d)
+ .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f)
+ .setField(basicMessageDesc.findFieldByName("bool_value"), true)
+ .setField(
+ basicMessageDesc.findFieldByName("bytes_value"),
+ ByteString.copyFromUtf8("ProtobufDeserializer"))
+ .build()
+
+ val dynamicMessage = DynamicMessage
+ .newBuilder(repeatedMessageDesc)
+ .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage)
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(
+ functions.from_protobuf($"value", testFileDesc, "RepeatedMessage").as("value_from"))
+ val toProtoDF = fromProtoDF.select(
+ functions.to_protobuf($"value_from", testFileDesc, "RepeatedMessage").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ functions.from_protobuf($"value_to", testFileDesc, "RepeatedMessage").as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("roundtrip in from_protobuf and to_protobuf - Repeated Message Twice") {
+ val repeatedMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "RepeatedMessage")
+ val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage")
+
+ val basicMessage1 = DynamicMessage
+ .newBuilder(basicMessageDesc)
+ .setField(basicMessageDesc.findFieldByName("id"), 1111L)
+ .setField(basicMessageDesc.findFieldByName("string_value"), "value1")
+ .setField(basicMessageDesc.findFieldByName("int32_value"), 12345)
+ .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L)
+ .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0d)
+ .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f)
+ .setField(basicMessageDesc.findFieldByName("bool_value"), true)
+ .setField(
+ basicMessageDesc.findFieldByName("bytes_value"),
+ ByteString.copyFromUtf8("ProtobufDeserializer1"))
+ .build()
+ val basicMessage2 = DynamicMessage
+ .newBuilder(basicMessageDesc)
+ .setField(basicMessageDesc.findFieldByName("id"), 1112L)
+ .setField(basicMessageDesc.findFieldByName("string_value"), "value2")
+ .setField(basicMessageDesc.findFieldByName("int32_value"), 12346)
+ .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L)
+ .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0d)
+ .setField(basicMessageDesc.findFieldByName("float_value"), 10903.0f)
+ .setField(basicMessageDesc.findFieldByName("bool_value"), false)
+ .setField(
+ basicMessageDesc.findFieldByName("bytes_value"),
+ ByteString.copyFromUtf8("ProtobufDeserializer2"))
+ .build()
+
+ val dynamicMessage = DynamicMessage
+ .newBuilder(repeatedMessageDesc)
+ .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage1)
+ .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage2)
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(
+ functions.from_protobuf($"value", testFileDesc, "RepeatedMessage").as("value_from"))
+ val toProtoDF = fromProtoDF.select(
+ functions.to_protobuf($"value_from", testFileDesc, "RepeatedMessage").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ functions.from_protobuf($"value_to", testFileDesc, "RepeatedMessage").as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("roundtrip in from_protobuf and to_protobuf - Map") {
+ val messageMapDesc = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageMap")
+
+ val mapStr1 = DynamicMessage
+ .newBuilder(messageMapDesc.findNestedTypeByName("StringMapdataEntry"))
+ .setField(
+ messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("key"),
+ "string_key")
+ .setField(
+ messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("value"),
+ "value1")
+ .build()
+ val mapStr2 = DynamicMessage
+ .newBuilder(messageMapDesc.findNestedTypeByName("StringMapdataEntry"))
+ .setField(
+ messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("key"),
+ "string_key")
+ .setField(
+ messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("value"),
+ "value2")
+ .build()
+ val mapInt64 = DynamicMessage
+ .newBuilder(messageMapDesc.findNestedTypeByName("Int64MapdataEntry"))
+ .setField(
+ messageMapDesc.findNestedTypeByName("Int64MapdataEntry").findFieldByName("key"),
+ 0x90000000000L)
+ .setField(
+ messageMapDesc.findNestedTypeByName("Int64MapdataEntry").findFieldByName("value"),
+ 0x90000000001L)
+ .build()
+ val mapInt32 = DynamicMessage
+ .newBuilder(messageMapDesc.findNestedTypeByName("Int32MapdataEntry"))
+ .setField(
+ messageMapDesc.findNestedTypeByName("Int32MapdataEntry").findFieldByName("key"),
+ 12345)
+ .setField(
+ messageMapDesc.findNestedTypeByName("Int32MapdataEntry").findFieldByName("value"),
+ 54321)
+ .build()
+ val mapFloat = DynamicMessage
+ .newBuilder(messageMapDesc.findNestedTypeByName("FloatMapdataEntry"))
+ .setField(
+ messageMapDesc.findNestedTypeByName("FloatMapdataEntry").findFieldByName("key"),
+ "float_key")
+ .setField(
+ messageMapDesc.findNestedTypeByName("FloatMapdataEntry").findFieldByName("value"),
+ 109202.234f)
+ .build()
+ val mapDouble = DynamicMessage
+ .newBuilder(messageMapDesc.findNestedTypeByName("DoubleMapdataEntry"))
+ .setField(
+ messageMapDesc.findNestedTypeByName("DoubleMapdataEntry").findFieldByName("key"),
+ "double_key")
+ .setField(
+ messageMapDesc.findNestedTypeByName("DoubleMapdataEntry").findFieldByName("value"),
+ 109202.12d)
+ .build()
+ val mapBool = DynamicMessage
+ .newBuilder(messageMapDesc.findNestedTypeByName("BoolMapdataEntry"))
+ .setField(
+ messageMapDesc.findNestedTypeByName("BoolMapdataEntry").findFieldByName("key"),
+ true)
+ .setField(
+ messageMapDesc.findNestedTypeByName("BoolMapdataEntry").findFieldByName("value"),
+ false)
+ .build()
+
+ val dynamicMessage = DynamicMessage
+ .newBuilder(messageMapDesc)
+ .setField(messageMapDesc.findFieldByName("key"), "key")
+ .setField(messageMapDesc.findFieldByName("value"), "value")
+ .addRepeatedField(messageMapDesc.findFieldByName("string_mapdata"), mapStr1)
+ .addRepeatedField(messageMapDesc.findFieldByName("string_mapdata"), mapStr2)
+ .addRepeatedField(messageMapDesc.findFieldByName("int64_mapdata"), mapInt64)
+ .addRepeatedField(messageMapDesc.findFieldByName("int32_mapdata"), mapInt32)
+ .addRepeatedField(messageMapDesc.findFieldByName("float_mapdata"), mapFloat)
+ .addRepeatedField(messageMapDesc.findFieldByName("double_mapdata"), mapDouble)
+ .addRepeatedField(messageMapDesc.findFieldByName("bool_mapdata"), mapBool)
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(
+ functions.from_protobuf($"value", testFileDesc, "SimpleMessageMap").as("value_from"))
+ val toProtoDF = fromProtoDF.select(
+ functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageMap").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ functions.from_protobuf($"value_to", testFileDesc, "SimpleMessageMap").as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("roundtrip in from_protobuf and to_protobuf - Enum") {
+ val messageEnumDesc = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageEnum")
+ val basicEnumDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicEnumMessage")
+
+ val dynamicMessage = DynamicMessage
+ .newBuilder(messageEnumDesc)
+ .setField(messageEnumDesc.findFieldByName("key"), "key")
+ .setField(messageEnumDesc.findFieldByName("value"), "value")
+ .setField(
+ messageEnumDesc.findFieldByName("nested_enum"),
+ messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING"))
+ .setField(
+ messageEnumDesc.findFieldByName("nested_enum"),
+ messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST"))
+ .setField(
+ messageEnumDesc.findFieldByName("basic_enum"),
+ basicEnumDesc.findEnumTypeByName("BasicEnum").findValueByName("FIRST"))
+ .setField(
+ messageEnumDesc.findFieldByName("basic_enum"),
+ basicEnumDesc.findEnumTypeByName("BasicEnum").findValueByName("NOTHING"))
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(
+ functions.from_protobuf($"value", testFileDesc, "SimpleMessageEnum").as("value_from"))
+ val toProtoDF = fromProtoDF.select(
+ functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageEnum").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ functions.from_protobuf($"value_to", testFileDesc, "SimpleMessageEnum").as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("roundtrip in from_protobuf and to_protobuf - Multiple Message") {
+ val messageMultiDesc = ProtobufUtils.buildDescriptor(testFileDesc, "MultipleExample")
+ val messageIncludeDesc = ProtobufUtils.buildDescriptor(testFileDesc, "IncludedExample")
+ val messageOtherDesc = ProtobufUtils.buildDescriptor(testFileDesc, "OtherExample")
+
+ val otherMessage = DynamicMessage
+ .newBuilder(messageOtherDesc)
+ .setField(messageOtherDesc.findFieldByName("other"), "other value")
+ .build()
+
+ val includeMessage = DynamicMessage
+ .newBuilder(messageIncludeDesc)
+ .setField(messageIncludeDesc.findFieldByName("included"), "included value")
+ .setField(messageIncludeDesc.findFieldByName("other"), otherMessage)
+ .build()
+
+ val dynamicMessage = DynamicMessage
+ .newBuilder(messageMultiDesc)
+ .setField(messageMultiDesc.findFieldByName("included_example"), includeMessage)
+ .build()
+
+ val df = Seq(dynamicMessage.toByteArray).toDF("value")
+ val fromProtoDF = df.select(
+ functions.from_protobuf($"value", testFileDesc, "MultipleExample").as("value_from"))
+ val toProtoDF = fromProtoDF.select(
+ functions.to_protobuf($"value_from", testFileDesc, "MultipleExample").as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ functions.from_protobuf($"value_to", testFileDesc, "MultipleExample").as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
+
+ test("Handle recursive fields in Protobuf schema, A->B->A") {
+ val schemaA = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveA")
+ val schemaB = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveB")
+
+ val messageBForA = DynamicMessage
+ .newBuilder(schemaB)
+ .setField(schemaB.findFieldByName("keyB"), "key")
+ .build()
+
+ val messageA = DynamicMessage
+ .newBuilder(schemaA)
+ .setField(schemaA.findFieldByName("keyA"), "key")
+ .setField(schemaA.findFieldByName("messageB"), messageBForA)
+ .build()
+
+ val messageB = DynamicMessage
+ .newBuilder(schemaB)
+ .setField(schemaB.findFieldByName("keyB"), "key")
+ .setField(schemaB.findFieldByName("messageA"), messageA)
+ .build()
+
+ val df = Seq(messageB.toByteArray).toDF("messageB")
+
+ val e = intercept[IncompatibleSchemaException] {
+ df.select(
+ functions.from_protobuf($"messageB", testFileDesc, "recursiveB").as("messageFromProto"))
+ .show()
+ }
+ val expectedMessage = s"""
+ |Found recursive reference in Protobuf schema, which can not be processed by Spark:
+ |org.apache.spark.sql.protobuf.recursiveB.messageA""".stripMargin
+ assert(e.getMessage == expectedMessage)
+ }
+
+ test("Handle recursive fields in Protobuf schema, C->D->Array(C)") {
+ val schemaC = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveC")
+ val schemaD = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveD")
+
+ val messageDForC = DynamicMessage
+ .newBuilder(schemaD)
+ .setField(schemaD.findFieldByName("keyD"), "key")
+ .build()
+
+ val messageC = DynamicMessage
+ .newBuilder(schemaC)
+ .setField(schemaC.findFieldByName("keyC"), "key")
+ .setField(schemaC.findFieldByName("messageD"), messageDForC)
+ .build()
+
+ val messageD = DynamicMessage
+ .newBuilder(schemaD)
+ .setField(schemaD.findFieldByName("keyD"), "key")
+ .addRepeatedField(schemaD.findFieldByName("messageC"), messageC)
+ .build()
+
+ val df = Seq(messageD.toByteArray).toDF("messageD")
+
+ val e = intercept[IncompatibleSchemaException] {
+ df.select(
+ functions.from_protobuf($"messageD", testFileDesc, "recursiveD").as("messageFromProto"))
+ .show()
+ }
+ val expectedMessage =
+ s"""
+ |Found recursive reference in Protobuf schema, which can not be processed by Spark:
+ |org.apache.spark.sql.protobuf.recursiveD.messageC""".stripMargin
+ assert(e.getMessage == expectedMessage)
+ }
+
+ test("Handle extra fields : oldProducer -> newConsumer") {
+ val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
+ val oldProducer = ProtobufUtils.buildDescriptor(testFileDesc, "oldProducer")
+ val newConsumer = ProtobufUtils.buildDescriptor(testFileDesc, "newConsumer")
+
+ val oldProducerMessage = DynamicMessage
+ .newBuilder(oldProducer)
+ .setField(oldProducer.findFieldByName("key"), "key")
+ .build()
+
+ val df = Seq(oldProducerMessage.toByteArray).toDF("oldProducerData")
+ val fromProtoDf = df.select(
+ functions
+ .from_protobuf($"oldProducerData", testFileDesc, "newConsumer")
+ .as("fromProto"))
+
+ val toProtoDf = fromProtoDf.select(
+ functions
+ .to_protobuf($"fromProto", testFileDesc, "newConsumer")
+ .as("toProto"))
+
+ val toProtoDfToFromProtoDf = toProtoDf.select(
+ functions
+ .from_protobuf($"toProto", testFileDesc, "newConsumer")
+ .as("toProtoToFromProto"))
+
+ val actualFieldNames =
+ toProtoDfToFromProtoDf.select("toProtoToFromProto.*").schema.fields.toSeq.map(f => f.name)
+ newConsumer.getFields.asScala.foreach { f =>
+ assert(actualFieldNames.contains(f.getName))
+ }
+ assert(
+ toProtoDfToFromProtoDf.select("toProtoToFromProto.value").take(1).toSeq(0).get(0) == null)
+ assert(
+ toProtoDfToFromProtoDf.select("toProtoToFromProto.actual.*").take(1).toSeq(0).get(0) == null)
+ }
+
+ test("Handle extra fields : newProducer -> oldConsumer") {
+ val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
+ val newProducer = ProtobufUtils.buildDescriptor(testFileDesc, "newProducer")
+ val oldConsumer = ProtobufUtils.buildDescriptor(testFileDesc, "oldConsumer")
+
+ val newProducerMessage = DynamicMessage
+ .newBuilder(newProducer)
+ .setField(newProducer.findFieldByName("key"), "key")
+ .setField(newProducer.findFieldByName("value"), 1)
+ .build()
+
+ val df = Seq(newProducerMessage.toByteArray).toDF("newProducerData")
+ val fromProtoDf = df.select(
+ functions
+ .from_protobuf($"newProducerData", testFileDesc, "oldConsumer")
+ .as("oldConsumerProto"))
+
+ val expectedFieldNames = oldConsumer.getFields.asScala.map(f => f.getName)
+ fromProtoDf.select("oldConsumerProto.*").schema.fields.foreach { f =>
+ assert(expectedFieldNames.contains(f.name))
+ }
+ }
+
+ test("roundtrip in to_protobuf and from_protobuf - with nulls") {
+ val schema = StructType(
+ StructField("requiredMsg",
+ StructType(
+ StructField("key", StringType, nullable = false) ::
+ StructField("col_1", IntegerType, nullable = true) ::
+ StructField("col_2", StringType, nullable = false) ::
+ StructField("col_3", IntegerType, nullable = true) :: Nil
+ ),
+ nullable = true
+ ) :: Nil
+ )
+ val inputDf = spark.createDataFrame(
+ spark.sparkContext.parallelize(Seq(
+ Row(Row("key1", null, "value2", null))
+ )),
+ schema
+ )
+ val toProtobuf = inputDf.select(
+ functions.to_protobuf($"requiredMsg", testFileDesc, "requiredMsg")
+ .as("to_proto"))
+
+ val binary = toProtobuf.take(1).toSeq(0).get(0).asInstanceOf[Array[Byte]]
+
+ val messageDescriptor = ProtobufUtils.buildDescriptor(testFileDesc, "requiredMsg")
+ val actualMessage = DynamicMessage.parseFrom(messageDescriptor, binary)
+
+ assert(actualMessage.getField(messageDescriptor.findFieldByName("key"))
+ == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0))
+ assert(actualMessage.getField(messageDescriptor.findFieldByName("col_2"))
+ == inputDf.select("requiredMsg.col_2").take(1).toSeq(0).get(0))
+ assert(actualMessage.getField(messageDescriptor.findFieldByName("col_1")) == 0)
+ assert(actualMessage.getField(messageDescriptor.findFieldByName("col_3")) == 0)
+
+ val fromProtoDf = toProtobuf.select(
+ functions.from_protobuf($"to_proto", testFileDesc, "requiredMsg") as 'from_proto)
+
+ assert(fromProtoDf.select("from_proto.key").take(1).toSeq(0).get(0)
+ == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0))
+ assert(fromProtoDf.select("from_proto.col_2").take(1).toSeq(0).get(0)
+ == inputDf.select("requiredMsg.col_2").take(1).toSeq(0).get(0))
+ assert(fromProtoDf.select("from_proto.col_1").take(1).toSeq(0).get(0) == null)
+ assert(fromProtoDf.select("from_proto.col_3").take(1).toSeq(0).get(0) == null)
+ }
+
+ test("from_protobuf filter to_protobuf") {
+ val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage")
+
+ val basicMessage = DynamicMessage
+ .newBuilder(basicMessageDesc)
+ .setField(basicMessageDesc.findFieldByName("id"), 1111L)
+ .setField(basicMessageDesc.findFieldByName("string_value"), "slam")
+ .setField(basicMessageDesc.findFieldByName("int32_value"), 12345)
+ .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L)
+ .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0d)
+ .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f)
+ .setField(basicMessageDesc.findFieldByName("bool_value"), true)
+ .setField(
+ basicMessageDesc.findFieldByName("bytes_value"),
+ ByteString.copyFromUtf8("ProtobufDeserializer"))
+ .build()
+
+ val df = Seq(basicMessage.toByteArray).toDF("value")
+ val resultFrom = df
+ .select(functions.from_protobuf($"value", testFileDesc, "BasicMessage") as 'sample)
+ .where("sample.string_value == \"slam\"")
+
+ val resultToFrom = resultFrom
+ .select(functions.to_protobuf($"sample", testFileDesc, "BasicMessage") as 'value)
+ .select(functions.from_protobuf($"value", testFileDesc, "BasicMessage") as 'sample)
+ .where("sample.string_value == \"slam\"")
+
+ assert(resultFrom.except(resultToFrom).isEmpty)
+ }
+
+ test("Handle TimestampType between to_protobuf and from_protobuf") {
+ val schema = StructType(
+ StructField("timeStampMsg",
+ StructType(
+ StructField("key", StringType, nullable = true) ::
+ StructField("stmp", TimestampType, nullable = true) :: Nil
+ ),
+ nullable = true
+ ) :: Nil
+ )
+
+ val inputDf = spark.createDataFrame(
+ spark.sparkContext.parallelize(Seq(
+ Row(Row("key1", Timestamp.valueOf("2016-05-09 10:12:43.999")))
+ )),
+ schema
+ )
+
+ val toProtoDf = inputDf
+ .select(functions.to_protobuf($"timeStampMsg", testFileDesc, "timeStampMsg") as 'to_proto)
+
+ val fromProtoDf = toProtoDf
+ .select(functions.from_protobuf($"to_proto", testFileDesc, "timeStampMsg") as 'timeStampMsg)
+ fromProtoDf.show(truncate = false)
+
+ val actualFields = fromProtoDf.schema.fields.toList
+ val expectedFields = inputDf.schema.fields.toList
+
+ assert(actualFields.size === expectedFields.size)
+ assert(actualFields === expectedFields)
+ assert(fromProtoDf.select("timeStampMsg.key").take(1).toSeq(0).get(0)
+ === inputDf.select("timeStampMsg.key").take(1).toSeq(0).get(0))
+ assert(fromProtoDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0)
+ === inputDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0))
+ }
+
+ test("Handle DayTimeIntervalType between to_protobuf and from_protobuf") {
+ val schema = StructType(
+ StructField("durationMsg",
+ StructType(
+ StructField("key", StringType, nullable = true) ::
+ StructField("duration",
+ DayTimeIntervalType.defaultConcreteType, nullable = true) :: Nil
+ ),
+ nullable = true
+ ) :: Nil
+ )
+
+ val inputDf = spark.createDataFrame(
+ spark.sparkContext.parallelize(Seq(
+ Row(Row("key1",
+ Duration.ofDays(1).plusHours(2).plusMinutes(3).plusSeconds(4)
+ ))
+ )),
+ schema
+ )
+
+ val toProtoDf = inputDf
+ .select(functions.to_protobuf($"durationMsg", testFileDesc, "durationMsg") as 'to_proto)
+
+ val fromProtoDf = toProtoDf
+ .select(functions.from_protobuf($"to_proto", testFileDesc, "durationMsg") as 'durationMsg)
+
+ val actualFields = fromProtoDf.schema.fields.toList
+ val expectedFields = inputDf.schema.fields.toList
+
+ assert(actualFields.size === expectedFields.size)
+ assert(actualFields === expectedFields)
+ assert(fromProtoDf.select("durationMsg.key").take(1).toSeq(0).get(0)
+ === inputDf.select("durationMsg.key").take(1).toSeq(0).get(0))
+ assert(fromProtoDf.select("durationMsg.duration").take(1).toSeq(0).get(0)
+ === inputDf.select("durationMsg.duration").take(1).toSeq(0).get(0))
+
+ }
+}
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
new file mode 100644
index 0000000000000..37c59743e7714
--- /dev/null
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.protobuf
+
+import com.google.protobuf.Descriptors.Descriptor
+import com.google.protobuf.DynamicMessage
+
+import org.apache.spark.sql.catalyst.NoopFilters
+import org.apache.spark.sql.protobuf.utils.ProtobufUtils
+import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructType}
+
+/**
+ * Tests for [[ProtobufSerializer]] and [[ProtobufDeserializer]], with a more specific focus
+ * on those classes than the end-to-end coverage in [[ProtobufFunctionsSuite]].
+ */
+class ProtobufSerdeSuite extends SharedSparkSession {
+
+ import ProtoSerdeSuite._
+ import ProtoSerdeSuite.MatchType._
+
+ val testFileDesc = testFile("protobuf/serde_suite.desc").replace("file:/", "/")
+
+ test("Test basic conversion") {
+ withFieldMatchType { fieldMatch =>
+ val (top, nest) = fieldMatch match {
+ case BY_NAME => ("foo", "bar")
+ }
+ val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage")
+
+ val dynamicMessageFoo = DynamicMessage
+ .newBuilder(protoFile.getFile.findMessageTypeByName("Foo"))
+ .setField(protoFile.getFile.findMessageTypeByName("Foo").findFieldByName("bar"), 10902)
+ .build()
+
+ val dynamicMessage = DynamicMessage
+ .newBuilder(protoFile)
+ .setField(protoFile.findFieldByName("foo"), dynamicMessageFoo)
+ .build()
+
+ val serializer = Serializer.create(CATALYST_STRUCT, protoFile, fieldMatch)
+ val deserializer = Deserializer.create(CATALYST_STRUCT, protoFile, fieldMatch)
+
+ assert(
+ serializer.serialize(deserializer.deserialize(dynamicMessage).get) === dynamicMessage)
+ }
+ }
+
+ test("Fail to convert with field type mismatch") {
+ val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "MissMatchTypeInRoot")
+
+ withFieldMatchType { fieldMatch =>
+ assertFailedConversionMessage(
+ protoFile,
+ Deserializer,
+ fieldMatch,
+ "Cannot convert Protobuf field 'foo' to SQL field 'foo' because schema is incompatible " +
+ s"(protoType = org.apache.spark.sql.protobuf.MissMatchTypeInRoot.foo " +
+ s"LABEL_OPTIONAL LONG INT64, sqlType = ${CATALYST_STRUCT.head.dataType.sql})".stripMargin)
+
+ assertFailedConversionMessage(
+ protoFile,
+ Serializer,
+ fieldMatch,
+ s"Cannot convert SQL field 'foo' to Protobuf field 'foo' because schema is incompatible " +
+ s"""(sqlType = ${CATALYST_STRUCT.head.dataType.sql}, protoType = LONG)""")
+ }
+ }
+
+ test("Fail to convert with missing nested Protobuf fields for serializer") {
+ val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInProto")
+
+ val nonnullCatalyst = new StructType()
+ .add("foo", new StructType().add("bar", IntegerType, nullable = false))
+
+ // serialize fails whether or not 'bar' is nullable
+ val byNameMsg = "Cannot find field 'foo.bar' in Protobuf schema"
+ assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg)
+ assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg, nonnullCatalyst)
+ }
+
+ test("Fail to convert with deeply nested field type mismatch") {
+ val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "MissMatchTypeInDeepNested")
+ val catalyst = new StructType().add("top", CATALYST_STRUCT)
+
+ withFieldMatchType { fieldMatch =>
+ assertFailedConversionMessage(
+ protoFile,
+ Deserializer,
+ fieldMatch,
+ s"Cannot convert Protobuf field 'top.foo.bar' to SQL field 'top.foo.bar' because schema " +
+ s"is incompatible (protoType = org.apache.spark.sql.protobuf.TypeMiss.bar " +
+ s"LABEL_OPTIONAL LONG INT64, sqlType = INT)".stripMargin,
+ catalyst)
+
+ assertFailedConversionMessage(
+ protoFile,
+ Serializer,
+ fieldMatch,
+ "Cannot convert SQL field 'top.foo.bar' to Protobuf field 'top.foo.bar' because schema " +
+ """is incompatible (sqlType = INT, protoType = LONG)""",
+ catalyst)
+ }
+ }
+
+ test("Fail to convert with missing Catalyst fields") {
+ val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLRoot")
+
+ // serializing fails when the Protobuf schema has an extra field with no match in the SQL schema
+ assertFailedConversionMessage(
+ protoFile,
+ Serializer,
+ BY_NAME,
+ "Found field 'boo' in Protobuf schema but there is no match in the SQL schema")
+
+ /* deserializing should work whether or not the extra Protobuf field has a
+ match in the SQL schema */
+ withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _))
+ withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _))
+
+ val protoNestedFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLNested")
+
+ // serializing fails when the Protobuf schema has an extra field with no match in the SQL schema
+ assertFailedConversionMessage(
+ protoNestedFile,
+ Serializer,
+ BY_NAME,
+ "Found field 'foo.baz' in Protobuf schema but there is no match in the SQL schema")
+
+ /* deserializing should work whether or not the extra Protobuf field has a
+ match in the SQL schema */
+ withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _))
+ withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _))
+ }
+
+ /**
+ * Attempt to convert `catalystSchema` to `protoSchema` (or vice versa when `serdeFactory` is
+ * the Deserializer), assert that it fails, and assert that the _cause_ of the thrown
+ * exception has a message matching `expectedCauseMessage`.
+ */
+ private def assertFailedConversionMessage(
+ protoSchema: Descriptor,
+ serdeFactory: SerdeFactory[_],
+ fieldMatchType: MatchType,
+ expectedCauseMessage: String,
+ catalystSchema: StructType = CATALYST_STRUCT): Unit = {
+ val e = intercept[IncompatibleSchemaException] {
+ serdeFactory.create(catalystSchema, protoSchema, fieldMatchType)
+ }
+ val expectMsg = serdeFactory match {
+ case Deserializer =>
+ s"Cannot convert Protobuf type ${protoSchema.getName} to SQL type ${catalystSchema.sql}."
+ case Serializer =>
+ s"Cannot convert SQL type ${catalystSchema.sql} to Protobuf type ${protoSchema.getName}."
+ }
+
+ assert(e.getMessage === expectMsg)
+ assert(e.getCause.getMessage === expectedCauseMessage)
+ }
+
+ def withFieldMatchType(f: MatchType => Unit): Unit = {
+ MatchType.values.foreach { fieldMatchType =>
+ withClue(s"fieldMatchType == $fieldMatchType") {
+ f(fieldMatchType)
+ }
+ }
+ }
+}
+
+object ProtoSerdeSuite {
+
+ val CATALYST_STRUCT =
+ new StructType().add("foo", new StructType().add("bar", IntegerType))
+
+ /**
+ * Specifier for type of field matching to be used for easy creation of tests that do by-name
+ * field matching.
+ */
+ object MatchType extends Enumeration {
+ type MatchType = Value
+ val BY_NAME = Value
+ }
+
+ import MatchType._
+
+ /**
+ * Specifier for type of serde to be used for easy creation of tests that do both serialization
+ * and deserialization.
+ */
+ sealed trait SerdeFactory[T] {
+ def create(sqlSchema: StructType, descriptor: Descriptor, fieldMatchType: MatchType): T
+ }
+
+ object Serializer extends SerdeFactory[ProtobufSerializer] {
+ override def create(
+ sql: StructType,
+ descriptor: Descriptor,
+ matchType: MatchType): ProtobufSerializer = new ProtobufSerializer(sql, descriptor, false)
+ }
+
+ object Deserializer extends SerdeFactory[ProtobufDeserializer] {
+ override def create(
+ sql: StructType,
+ descriptor: Descriptor,
+ matchType: MatchType): ProtobufDeserializer =
+ new ProtobufDeserializer(descriptor, sql, new NoopFilters)
+ }
+}
diff --git a/pom.xml b/pom.xml
index f82546e4f3e9c..7258f970bab7c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -101,6 +101,7 @@
connector/kafka-0-10-sql
connector/avro
connector/connect
+ connector/protobuf
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 1de8bc6a47ded..15fa3a3143b60 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -45,8 +45,8 @@ object BuildCommons {
private val buildLocation = file(".").getAbsoluteFile.getParentFile
- val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro) = Seq(
- "catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10", "sql-kafka-0-10", "avro"
+ val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro, protobuf) = Seq(
+ "catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10", "sql-kafka-0-10", "avro", "protobuf"
).map(ProjectRef(buildLocation, _))
val streamingProjects@Seq(streaming, streamingKafka010) =
@@ -59,7 +59,7 @@ object BuildCommons {
) = Seq(
"core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe",
"tags", "sketch", "kvstore"
- ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ Seq(connect)
+ ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ Seq(connect) ++ Seq(protobuf)
val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn,
sparkGangliaLgpl, streamingKinesisAsl,
@@ -390,7 +390,7 @@ object SparkBuild extends PomBuild {
val mimaProjects = allProjects.filterNot { x =>
Seq(
spark, hive, hiveThriftServer, repl, networkCommon, networkShuffle, networkYarn,
- unsafe, tags, tokenProviderKafka010, sqlKafka010, connect
+ unsafe, tags, tokenProviderKafka010, sqlKafka010, connect, protobuf
).contains(x)
}
@@ -433,6 +433,9 @@ object SparkBuild extends PomBuild {
enable(SparkConnect.settings)(connect)
+ /* Connector/proto settings */
+ enable(SparkProtobuf.settings)(protobuf)
+
// SPARK-14738 - Remove docker tests from main Spark build
// enable(DockerIntegrationTests.settings)(dockerIntegrationTests)
@@ -662,6 +665,48 @@ object SparkConnect {
)
}
+object SparkProtobuf {
+
+ import BuildCommons.protoVersion
+
+ private val shadePrefix = "org.sparkproject.spark-protobuf"
+ val shadeJar = taskKey[Unit]("Shade the Jars")
+
+ lazy val settings = Seq(
+ // Set the version for the protobuf compiler. This has to be propagated to every sub-project,
+ // even if that project does not use it.
+ PB.protocVersion := BuildCommons.protoVersion,
+
+ // For some reason the resolution from the imported Maven build does not work for some
+ // of these dependencies that we need to shade later on.
+ libraryDependencies ++= Seq(
+ "com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf"
+ ),
+
+ dependencyOverrides ++= Seq(
+ "com.google.protobuf" % "protobuf-java" % protoVersion
+ ),
+
+ (Compile / PB.targets) := Seq(
+ PB.gens.java -> (Compile / sourceManaged).value,
+ ),
+
+ (assembly / test) := false,
+
+ (assembly / logLevel) := Level.Info,
+
+ (assembly / assemblyShadeRules) := Seq(
+ ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.spark-protobuf.protobuf.@1").inAll,
+ ),
+
+ (assembly / assemblyMergeStrategy) := {
+ case m if m.toLowerCase(Locale.ROOT).endsWith("manifest.mf") => MergeStrategy.discard
+ // Drop all proto files that are not needed as artifacts of the build.
+ case m if m.toLowerCase(Locale.ROOT).endsWith(".proto") => MergeStrategy.discard
+ case _ => MergeStrategy.first
+ },
+ )
+}
object Unsafe {
lazy val settings = Seq(
// This option is needed to suppress warnings from sun.misc.Unsafe usage
@@ -1107,10 +1152,10 @@ object Unidoc {
(ScalaUnidoc / unidoc / unidocProjectFilter) :=
inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes,
- yarn, tags, streamingKafka010, sqlKafka010, connect),
+ yarn, tags, streamingKafka010, sqlKafka010, connect, protobuf),
(JavaUnidoc / unidoc / unidocProjectFilter) :=
inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes,
- yarn, tags, streamingKafka010, sqlKafka010, connect),
+ yarn, tags, streamingKafka010, sqlKafka010, connect, protobuf),
(ScalaUnidoc / unidoc / unidocAllClasspaths) := {
ignoreClasspaths((ScalaUnidoc / unidoc / unidocAllClasspaths).value)
@@ -1196,6 +1241,7 @@ object CopyDependencies {
// produce the shaded Jar which happens automatically in the case of Maven.
// Later, when the dependencies are copied, we manually copy the shaded Jar only.
val fid = (LocalProject("connect") / assembly).value
+ val fidProtobuf = (LocalProject("protobuf")/assembly).value
(Compile / dependencyClasspath).value.map(_.data)
.filter { jar => jar.isFile() }
@@ -1208,6 +1254,9 @@ object CopyDependencies {
if (jar.getName.contains("spark-connect") &&
!SbtPomKeys.profiles.value.contains("noshade-connect")) {
Files.copy(fid.toPath, destJar.toPath)
+ } else if (jar.getName.contains("spark-protobuf") &&
+ !SbtPomKeys.profiles.value.contains("noshade-protobuf")) {
+ Files.copy(fidProtobuf.toPath, destJar.toPath)
} else {
Files.copy(jar.toPath(), destJar.toPath())
}