diff --git a/core/2.4/src/main/scala/org/apache/spark/sql/catalyst/KotlinReflection.scala b/core/2.4/src/main/scala/org/apache/spark/sql/catalyst/KotlinReflection.scala
index 774e0bc0..c847b16a 100644
--- a/core/2.4/src/main/scala/org/apache/spark/sql/catalyst/KotlinReflection.scala
+++ b/core/2.4/src/main/scala/org/apache/spark/sql/catalyst/KotlinReflection.scala
@@ -22,7 +22,6 @@ import java.lang.reflect.Type
 import java.lang.{Iterable => JIterable}
 import java.time.LocalDate
 import java.util.{Iterator => JIterator, List => JList, Map => JMap}
-
 import com.google.common.reflect.TypeToken
 import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
@@ -399,10 +398,38 @@ object KotlinReflection {
           getPath,
           customCollectionCls = Some(predefinedDt.get.cls))
 
+      case StructType(elementType: Array[StructField]) =>
+        val cls = t.cls
+
+        val arguments = elementType.map { field =>
+          val dataType = field.dataType.asInstanceOf[DataTypeWithClass]
+          val nullable = dataType.nullable
+          val clsName = dataType.cls.getName
+          val fieldName = field.asInstanceOf[KStructField].delegate.name
+          val newPath = addToPath(fieldName)
+
+          deserializerFor(
+            TypeToken.of(dataType.cls),
+            Some(newPath),
+            Some(dataType).filter(_.isInstanceOf[ComplexWrapper])
+          )
+        }
+        val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
+
+
+        if (path.nonEmpty) {
+          expressions.If(
+            IsNull(getPath),
+            expressions.Literal.create(null, ObjectType(cls)),
+            newInstance
+          )
+        } else {
+          newInstance
+        }
       case _ =>
         throw new UnsupportedOperationException(
-          s"No Encoder found for $typeToken")
+          s"No Encoder found for $typeToken in deserializerFor\n" + path)
     }
   }
 
@@ -608,8 +635,34 @@ object KotlinReflection {
       case ArrayType(elementType, _) =>
         toCatalystArray(inputObject, TypeToken.of(elementType.asInstanceOf[DataTypeWithClass].cls), Some(elementType.asInstanceOf[DataTypeWithClass]))
+
+      case StructType(elementType: Array[StructField]) =>
+        val cls = otherTypeWrapper.cls
+        val names = elementType.map(_.name)
+
+        val beanInfo = Introspector.getBeanInfo(cls)
+        val methods = beanInfo.getMethodDescriptors.filter(it => names.contains(it.getName))
+
+        val fields = elementType.map { structField =>
+
+          val maybeProp = methods.find(it => it.getName == structField.name)
+          if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${structField.name} is not found among available props, which are: ${methods.map(_.getName).mkString(", ")}")
+          val fieldName = structField.name
+          val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls
+          val propDt = structField.dataType.asInstanceOf[DataTypeWithClass]
+          val fieldValue = Invoke(
+            inputObject,
+            maybeProp.get.getName,
+            inferExternalType(propClass),
+            returnNullable = propDt.nullable
+          )
+          expressions.Literal(fieldName) :: serializerFor(fieldValue, TypeToken.of(propClass), propDt match { case c: ComplexWrapper => Some(c) case _ => None }) :: Nil
+        }
+        val nonNullOutput = CreateNamedStruct(fields.flatten.seq)
+        val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+        expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+
       case _ =>
-        throw new UnsupportedOperationException(s"No Encoder found for $typeToken.")
+        throw new UnsupportedOperationException(s"No Encoder found for $typeToken in serializerFor.\n$otherTypeWrapper")
$otherTypeWrapper") } diff --git a/core/3.0/src/main/scala/org/apache/spark/sql/KotlinReflection.scala b/core/3.0/src/main/scala/org/apache/spark/sql/KotlinReflection.scala index 67feba26..01a2e3fb 100644 --- a/core/3.0/src/main/scala/org/apache/spark/sql/KotlinReflection.scala +++ b/core/3.0/src/main/scala/org/apache/spark/sql/KotlinReflection.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql -import java.beans.{Introspector, PropertyDescriptor} - import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.DeserializerBuildHelper._ import org.apache.spark.sql.catalyst.SerializerBuildHelper._ @@ -33,6 +31,8 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection, WalkedTypePa import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import java.beans.{Introspector, PropertyDescriptor} + /** * A helper trait to create [[org.apache.spark.sql.catalyst.encoders.ExpressionEncoder]]s @@ -440,11 +440,79 @@ object KotlinReflection extends KotlinReflection { UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(t.cls)) + case StructType(elementType: Array[StructField]) => + val cls = t.cls + + val arguments = elementType.map { field => + val dataType = field.dataType.asInstanceOf[DataTypeWithClass] + val nullable = dataType.nullable + val clsName = getClassNameFromType(getType(dataType.cls)) + val newTypePath = walkedTypePath.recordField(clsName, field.name) + + // For tuples, we based grab the inner fields by ordinal instead of name. + val newPath = deserializerFor( + getType(dataType.cls), + addToPath(path, field.name, dataType.dt, newTypePath), + newTypePath, + Some(dataType).filter(_.isInstanceOf[ComplexWrapper]) + ) + expressionWithNullSafety( + newPath, + nullable = nullable, + newTypePath + ) + } + val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) + + org.apache.spark.sql.catalyst.expressions.If( + IsNull(path), + org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)), + newInstance + ) + + case _ => throw new UnsupportedOperationException( s"No Encoder found for $tpe\n" + walkedTypePath) } } + + case t if definedByConstructorParams(t) => + val params = getConstructorParameters(t) + + val cls = getClassFromType(tpe) + + val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => + val Schema(dataType, nullable) = schemaFor(fieldType) + val clsName = getClassNameFromType(fieldType) + val newTypePath = walkedTypePath.recordField(clsName, fieldName) + + // For tuples, we based grab the inner fields by ordinal instead of name. 
+          val newPath = if (cls.getName startsWith "scala.Tuple") {
+            deserializerFor(
+              fieldType,
+              addToPathOrdinal(path, i, dataType, newTypePath),
+              newTypePath)
+          } else {
+            deserializerFor(
+              fieldType,
+              addToPath(path, fieldName, dataType, newTypePath),
+              newTypePath)
+          }
+          expressionWithNullSafety(
+            newPath,
+            nullable = nullable,
+            newTypePath)
+        }
+
+        val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
+
+        org.apache.spark.sql.catalyst.expressions.If(
+          IsNull(path),
+          org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)),
+          newInstance
+        )
+
      case _ =>
        throw new UnsupportedOperationException(
          s"No Encoder found for $tpe\n" + walkedTypePath)
@@ -519,7 +587,7 @@ object KotlinReflection extends KotlinReflection {
   def toCatalystArray(input: Expression, elementType: `Type`, predefinedDt: Option[DataTypeWithClass] = None): Expression = {
     predefinedDt.map(_.dt).getOrElse(dataTypeFor(elementType)) match {
-      case dt:StructType =>
+      case dt: StructType =>
         val clsName = getClassNameFromType(elementType)
         val newPath = walkedTypePath.recordArray(clsName)
         createSerializerForMapObjects(input, ObjectType(predefinedDt.get.cls),
@@ -662,32 +730,6 @@ object KotlinReflection extends KotlinReflection {
         createSerializerForUserDefinedType(inputObject, udt, udtClass)
 
       //
-
-      case t if definedByConstructorParams(t) =>
-        if (seenTypeSet.contains(t)) {
-          throw new UnsupportedOperationException(
-            s"cannot have circular references in class, but got the circular reference of class $t")
-        }
-
-        val params = getConstructorParameters(t)
-        val fields = params.map { case (fieldName, fieldType) =>
-          if (javaKeywords.contains(fieldName)) {
-            throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
-              "cannot be used as field name\n" + walkedTypePath)
-          }
-
-          // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNul
-          // is necessary here. Because for a nullable nested inputObject with struct data
-          // type, e.g. StructType(IntegerType, StringType), it will return nullable=true
-          // for IntegerType without KnownNotNull. And that's what we do not expect to.
-          val fieldValue = Invoke(KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType),
-            returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
-          val clsName = getClassNameFromType(fieldType)
-          val newPath = walkedTypePath.recordField(clsName, fieldName)
-          (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t))
-        }
-        createSerializerForObject(inputObject, fields)
-
       case _ if predefinedDt.isDefined =>
         predefinedDt.get match {
           case dataType: KDataTypeWrapper =>
@@ -735,12 +777,66 @@ object KotlinReflection extends KotlinReflection {
             )
           case ArrayType(elementType, _) =>
             toCatalystArray(inputObject, getType(elementType.asInstanceOf[DataTypeWithClass].cls), Some(elementType.asInstanceOf[DataTypeWithClass]))
+
+          case StructType(elementType: Array[StructField]) =>
+            val cls = otherTypeWrapper.cls
+            val names = elementType.map(_.name)
+
+            val beanInfo = Introspector.getBeanInfo(cls)
+            val methods = beanInfo.getMethodDescriptors.filter(it => names.contains(it.getName))
+
+
+            val fields = elementType.map { structField =>
+
+              val maybeProp = methods.find(it => it.getName == structField.name)
+              if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${structField.name} is not found among available props, which are: ${methods.map(_.getName).mkString(", ")}")
+              val fieldName = structField.name
+              val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls
+              val propDt = structField.dataType.asInstanceOf[DataTypeWithClass]
+              val fieldValue = Invoke(
+                inputObject,
+                maybeProp.get.getName,
+                inferExternalType(propClass),
+                returnNullable = propDt.nullable
+              )
+              val newPath = walkedTypePath.recordField(propClass.getName, fieldName)
+              (fieldName, serializerFor(fieldValue, getType(propClass), newPath, seenTypeSet, if (propDt.isInstanceOf[ComplexWrapper]) Some(propDt) else None))
+
+            }
+            createSerializerForObject(inputObject, fields)
+
           case _ =>
            throw new UnsupportedOperationException(
              s"No Encoder found for $tpe\n" + walkedTypePath)
        }
      }
+
+      case t if definedByConstructorParams(t) =>
+        if (seenTypeSet.contains(t)) {
+          throw new UnsupportedOperationException(
+            s"cannot have circular references in class, but got the circular reference of class $t")
+        }
+
+        val params = getConstructorParameters(t)
+        val fields = params.map { case (fieldName, fieldType) =>
+          if (javaKeywords.contains(fieldName)) {
+            throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
+              "cannot be used as field name\n" + walkedTypePath)
+          }
+
+          // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNul
+          // is necessary here. Because for a nullable nested inputObject with struct data
+          // type, e.g. StructType(IntegerType, StringType), it will return nullable=true
+          // for IntegerType without KnownNotNull. And that's what we do not expect to.
+          val fieldValue = Invoke(KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType),
+            returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
+          val clsName = getClassNameFromType(fieldType)
+          val newPath = walkedTypePath.recordField(clsName, fieldName)
+          (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t))
+        }
+        createSerializerForObject(inputObject, fields)
+
      case _ =>
        throw new UnsupportedOperationException(
          s"No Encoder found for $tpe\n" + walkedTypePath)
diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
index 57d7d682..1df6ed17 100644
--- a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
+++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
@@ -36,6 +36,7 @@ import org.apache.spark.sql.streaming.GroupStateTimeout
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.*
 import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions
+import scala.*
 import scala.collection.Seq
 import scala.reflect.`ClassTag$`
 import java.beans.PropertyDescriptor
@@ -122,8 +123,6 @@ inline fun List.toDS(spark: SparkSession): Dataset =
  * It creates encoder for any given supported type T
  *
  * Supported types are data classes, primitives, and Lists, Maps and Arrays containing them
- * are you here?
- * Pavel??
  * @param T type, supported by Spark
  * @return generated encoder
  */
@@ -141,6 +140,7 @@ fun generateEncoder(type: KType, cls: KClass<*>): Encoder {
 private fun isSupportedClass(cls: KClass<*>): Boolean = cls.isData
     || cls.isSubclassOf(Map::class)
     || cls.isSubclassOf(Iterable::class)
+    || cls.isSubclassOf(Product::class)
     || cls.java.isArray
 
 @Suppress("UNCHECKED_CAST")
@@ -418,6 +418,20 @@ fun schema(type: KType, map: Map = mapOf()): DataType {
             )
             KDataTypeWrapper(structType, klass.java, true)
         }
+        klass.isSubclassOf(Product::class) -> {
+            val params = type.arguments.mapIndexed { i, it ->
+                "_${i + 1}" to it.type!!
+            }
+
+            val structType = DataTypes.createStructType(
+                params.map { (fieldName, fieldType) ->
+                    val dataType = schema(fieldType, types)
+                    KStructField(fieldName, StructField(fieldName, dataType, fieldType.isMarkedNullable, Metadata.empty()))
+                }.toTypedArray()
+            )
+
+            KComplexTypeWrapper(structType, klass.java, true)
+        }
         else -> throw IllegalArgumentException("$type is unsupported")
     }
 }
@@ -430,6 +444,8 @@ enum class SparkLogLevel {
     ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN
 }
 
+val timestampDt = `TimestampType$`.`MODULE$`
+val dateDt = `DateType$`.`MODULE$`
 private val knownDataTypes = mapOf(
     Byte::class to DataTypes.ByteType,
     Short::class to DataTypes.ShortType,
@@ -439,10 +455,10 @@ private val knownDataTypes = mapOf(
     Float::class to DataTypes.FloatType,
     Double::class to DataTypes.DoubleType,
     String::class to DataTypes.StringType,
-    LocalDate::class to `DateType$`.`MODULE$`,
-    Date::class to `DateType$`.`MODULE$`,
-    Timestamp::class to `TimestampType$`.`MODULE$`,
-    Instant::class to `TimestampType$`.`MODULE$`
+    LocalDate::class to dateDt,
+    Date::class to dateDt,
+    Timestamp::class to timestampDt,
+    Instant::class to timestampDt
 )
 
 private fun transitiveMerge(a: Map, b: Map): Map {
@@ -459,4 +475,4 @@ class Memoize1(val f: (T) -> R) : (T) -> R {
 private fun ((T) -> R).memoize(): (T) -> R = Memoize1(this)
 
-private val memoizedSchema = { x: KType -> schema(x) }.memoize()
\ No newline at end of file
+private val memoizedSchema = { x: KType -> schema(x) }.memoize()
diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Conversions.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Conversions.kt
index 81cb1761..2696c405 100644
--- a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Conversions.kt
+++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Conversions.kt
@@ -1,4 +1,23 @@
-@file:Suppress("NOTHING_TO_INLINE", "RemoveExplicitTypeArguments")
+/*-
+ * =LICENSE=
+ * Kotlin Spark API: API for Spark 2.4+ (Scala 2.12)
+ * ----------
+ * Copyright (C) 2019 - 2021 JetBrains
+ * ----------
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =LICENSEEND=
+ */
+@file:Suppress("NOTHING_TO_INLINE", "RemoveExplicitTypeArguments", "unused")
 
 package org.jetbrains.kotlinx.spark.api
 
diff --git a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
index 08ffee12..b8176842 100644
--- a/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
+++ b/kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkHelper.kt
@@ -51,12 +51,6 @@ inline fun withSpark(props: Map = emptyMap(), master: String = "loc
 }
 
-/**
- * Pavel hello!
- * Hello, World!
- * How are you??
- *
- */
 @JvmOverloads
 inline fun withSpark(builder: Builder, logLevel: SparkLogLevel = ERROR, func: KSparkSession.() -> Unit) {
     builder
diff --git a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
index 34e41482..2782695a 100644
--- a/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
+++ b/kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
@@ -26,6 +26,10 @@ import org.apache.spark.sql.streaming.GroupState
 import org.apache.spark.sql.streaming.GroupStateTimeout
 import scala.collection.Seq
 import org.apache.spark.sql.Dataset
+import scala.Product
+import scala.Tuple1
+import scala.Tuple2
+import scala.Tuple3
 import java.io.Serializable
 import java.sql.Date
 import java.sql.Timestamp
@@ -302,10 +306,32 @@ class ApiTest : ShouldSpec({
                 val dataset = dsOf(Timestamp(0L) to 2)
                 dataset.show()
             }
+            should("Be able to serialize Scala Tuples including data classes") {
+                val dataset = dsOf(
+                    Tuple2("a", Tuple3("a", 1, LonLat(1.0, 1.0))),
+                    Tuple2("b", Tuple3("b", 2, LonLat(1.0, 2.0))),
+                )
+                dataset.show()
+                val asList = dataset.takeAsList(2)
+                asList.first() shouldBe Tuple2("a", Tuple3("a", 1, LonLat(1.0, 1.0)))
+            }
+            should("Be able to serialize data classes with tuples") {
+                val dataset = dsOf(
+                    DataClassWithTuple(Tuple3(5L, "test", Tuple1(""))),
+                    DataClassWithTuple(Tuple3(6L, "tessst", Tuple1(""))),
+                )
+
+                dataset.show()
+                val asList = dataset.takeAsList(2)
+                asList.first().tuple shouldBe Tuple3(5L, "test", Tuple1(""))
+            }
         }
     }
 })
+data class DataClassWithTuple<T>(val tuple: T)
+
+
 data class LonLat(val lon: Double, val lat: Double)
 data class Test(val id: Long, val data: Array>) {
     override fun equals(other: Any?): Boolean {
diff --git a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
index d5adc3bd..eef97acf 100644
--- a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
+++ b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt
@@ -33,6 +33,7 @@ import org.apache.spark.sql.streaming.GroupStateTimeout
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.*
 import org.jetbrains.kotinx.spark.extensions.KSparkExtensions
+import scala.*
 import scala.reflect.ClassTag
 import java.beans.PropertyDescriptor
 import java.math.BigDecimal
@@ -47,6 +48,7 @@ import kotlin.reflect.KType
 import kotlin.reflect.full.findAnnotation
 import kotlin.reflect.full.isSubclassOf
 import kotlin.reflect.full.primaryConstructor
+import kotlin.reflect.jvm.jvmErasure
 import kotlin.reflect.typeOf
 
 @JvmField
@@ -139,6 +141,7 @@ fun generateEncoder(type: KType, cls: KClass<*>): Encoder {
 private fun isSupportedClass(cls: KClass<*>): Boolean = cls.isData
     || cls.isSubclassOf(Map::class)
     || cls.isSubclassOf(Iterable::class)
+    || cls.isSubclassOf(Product::class)
    || cls.java.isArray
 
 private fun kotlinClassEncoder(schema: DataType, kClass: KClass<*>): Encoder {
@@ -408,6 +411,20 @@ fun schema(type: KType, map: Map = mapOf()): DataType {
             )
             KDataTypeWrapper(structType, klass.java, true)
         }
+        klass.isSubclassOf(Product::class) -> {
+            val params = type.arguments.mapIndexed { i, it ->
+                "_${i + 1}" to it.type!!
+            }
+
+            val structType = DataTypes.createStructType(
+                params.map { (fieldName, fieldType) ->
+                    val dataType = schema(fieldType, types)
+                    KStructField(fieldName, StructField(fieldName, dataType, fieldType.isMarkedNullable, Metadata.empty()))
+                }.toTypedArray()
+            )
+
+            KComplexTypeWrapper(structType, klass.java, true)
+        }
         else -> throw IllegalArgumentException("$type is unsupported")
     }
 }
diff --git a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Conversions.kt b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Conversions.kt
index 00562001..bc9ced24 100644
--- a/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Conversions.kt
+++ b/kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Conversions.kt
@@ -1,4 +1,23 @@
-@file:Suppress("NOTHING_TO_INLINE", "RemoveExplicitTypeArguments")
+/*-
+ * =LICENSE=
+ * Kotlin Spark API: API for Spark 3.0+ (Scala 2.12)
+ * ----------
+ * Copyright (C) 2019 - 2021 JetBrains
+ * ----------
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =LICENSEEND=
+ */
+@file:Suppress("NOTHING_TO_INLINE", "RemoveExplicitTypeArguments", "unused")
 
 package org.jetbrains.kotlinx.spark.api
 
diff --git a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
index 3522a68e..555b0b14 100644
--- a/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
+++ b/kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt
@@ -22,10 +22,14 @@ import ch.tutteli.atrium.domain.builders.migration.asExpect
 import ch.tutteli.atrium.verbs.expect
 import io.kotest.core.spec.style.ShouldSpec
 import io.kotest.matchers.shouldBe
+import scala.Tuple1
+import scala.Tuple2
+import scala.Tuple3
 import org.apache.spark.sql.streaming.GroupState
 import org.apache.spark.sql.streaming.GroupStateTimeout
 import scala.collection.Seq
 import org.apache.spark.sql.Dataset
+import scala.Product
 import java.io.Serializable
 import java.sql.Date
 import java.sql.Timestamp
@@ -324,10 +328,31 @@ class ApiTest : ShouldSpec({
                 val dataset = dsOf(Timestamp(0L) to 2)
                 dataset.show()
             }
+            should("Be able to serialize Scala Tuples including data classes") {
+                val dataset = dsOf(
+                    Tuple2("a", Tuple3("a", 1, LonLat(1.0, 1.0))),
+                    Tuple2("b", Tuple3("b", 2, LonLat(1.0, 2.0))),
+                )
+                dataset.show()
+                val asList = dataset.takeAsList(2)
+                asList.first() shouldBe Tuple2("a", Tuple3("a", 1, LonLat(1.0, 1.0)))
+            }
+            should("Be able to serialize data classes with tuples") {
+                val dataset = dsOf(
+                    DataClassWithTuple(Tuple3(5L, "test", Tuple1(""))),
+                    DataClassWithTuple(Tuple3(6L, "tessst", Tuple1(""))),
+                )
+
+                dataset.show()
+                val asList = dataset.takeAsList(2)
+                asList.first().tuple shouldBe Tuple3(5L, "test", Tuple1(""))
+            }
         }
     }
 })
+data class DataClassWithTuple<T>(val tuple: T)
+
 data class LonLat(val lon: Double, val lat: Double)
 
 // (data) class must be Serializable to be broadcast
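
Usage sketch: the Product branch added to schema() maps each tuple component to a struct field named _1, _2, ..., so Kotlin code can build Datasets of Scala tuples, including tuples that nest data classes. A minimal example, assuming the withSpark and dsOf helpers and the LonLat data class that appear in the diffs above:

    import org.jetbrains.kotlinx.spark.api.*
    import scala.Tuple2
    import scala.Tuple3

    // Same shape as the LonLat class used in the tests above.
    data class LonLat(val lon: Double, val lat: Double)

    fun main() = withSpark {
        // Scala tuples are encoded as structs with fields _1, _2, ...,
        // so they can be created, shown and collected like other supported types.
        val ds = dsOf(
            Tuple2("a", Tuple3("a", 1, LonLat(1.0, 1.0))),
            Tuple2("b", Tuple3("b", 2, LonLat(1.0, 2.0))),
        )
        ds.show()
        println(ds.takeAsList(2).first())
    }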