diff --git a/core/src/main/scala/org/apache/spark/sql/KotlinReflection.scala b/core/src/main/scala/org/apache/spark/sql/KotlinReflection.scala index f11012ef..4916ceb7 100644 --- a/core/src/main/scala/org/apache/spark/sql/KotlinReflection.scala +++ b/core/src/main/scala/org/apache/spark/sql/KotlinReflection.scala @@ -36,6 +36,7 @@ import org.apache.spark.util.Utils import java.beans.{Introspector, PropertyDescriptor} import java.lang.Exception +import java.lang.reflect.Method /** @@ -212,11 +213,11 @@ object KotlinReflection extends KotlinReflection { * @param walkedTypePath The paths from top to bottom to access current field when deserializing. */ private def deserializerFor( - tpe: `Type`, - path: Expression, - walkedTypePath: WalkedTypePath, - predefinedDt: Option[DataTypeWithClass] = None - ): Expression = cleanUpReflectionObjects { + tpe: `Type`, + path: Expression, + walkedTypePath: WalkedTypePath, + predefinedDt: Option[DataTypeWithClass] = None + ): Expression = cleanUpReflectionObjects { baseType(tpe) match { // @@ -685,18 +686,18 @@ object KotlinReflection extends KotlinReflection { * internal representation. */ private def serializerFor( - inputObject: Expression, - tpe: `Type`, - walkedTypePath: WalkedTypePath, - seenTypeSet: Set[`Type`] = Set.empty, - predefinedDt: Option[DataTypeWithClass] = None, - ): Expression = cleanUpReflectionObjects { + inputObject: Expression, + tpe: `Type`, + walkedTypePath: WalkedTypePath, + seenTypeSet: Set[`Type`] = Set.empty, + predefinedDt: Option[DataTypeWithClass] = None, + ): Expression = cleanUpReflectionObjects { def toCatalystArray( - input: Expression, - elementType: `Type`, - predefinedDt: Option[DataTypeWithClass] = None, - ): Expression = { + input: Expression, + elementType: `Type`, + predefinedDt: Option[DataTypeWithClass] = None, + ): Expression = { val dataType = predefinedDt .map(_.dt) .getOrElse { @@ -705,7 +706,7 @@ object KotlinReflection extends KotlinReflection { dataType match { - case dt @ (MapType(_, _, _) | ArrayType(_, _) | StructType(_)) => { + case dt@(MapType(_, _, _) | ArrayType(_, _) | StructType(_)) => { val clsName = getClassNameFromType(elementType) val newPath = walkedTypePath.recordArray(clsName) createSerializerForMapObjects( @@ -726,7 +727,7 @@ object KotlinReflection extends KotlinReflection { // case dt: ByteType => // createSerializerForPrimitiveArray(input, dt) - case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => { + case dt@(BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => { val cls = input.dataType.asInstanceOf[ObjectType].cls if (cls.isArray && cls.getComponentType.isPrimitive) { createSerializerForPrimitiveArray(input, dt) @@ -945,11 +946,11 @@ object KotlinReflection extends KotlinReflection { // Kotlin specific cases case t if predefinedDt.isDefined => { -// if (seenTypeSet.contains(t)) { -// throw new UnsupportedOperationException( -// s"cannot have circular references in class, but got the circular reference of class $t" -// ) -// } + // if (seenTypeSet.contains(t)) { + // throw new UnsupportedOperationException( + // s"cannot have circular references in class, but got the circular reference of class $t" + // ) + // } predefinedDt.get match { @@ -959,18 +960,20 @@ object KotlinReflection extends KotlinReflection { val properties = getJavaBeanReadableProperties(cls) val structFields = dataType.dt.fields.map(_.asInstanceOf[KStructField]) val fields: Array[(String, Expression)] = structFields.map { structField => - val maybeProp = properties.find(it => it.getReadMethod.getName == structField.getterName) - if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${ - structField.name - } is not found among available props, which are: ${properties.map(_.getName).mkString(", ")}" - ) + val maybeProp = properties.find { + _.getName == structField.getterName + } + if (maybeProp.isEmpty) + throw new IllegalArgumentException( + s"Field ${structField.name} is not found among available props, which are: ${properties.map(_.getName).mkString(", ")}" + ) val fieldName = structField.name val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls val propDt = structField.dataType.asInstanceOf[DataTypeWithClass] val fieldValue = Invoke( inputObject, - maybeProp.get.getReadMethod.getName, + maybeProp.get.getName, inferExternalType(propClass), returnNullable = structField.nullable ) @@ -1124,11 +1127,14 @@ object KotlinReflection extends KotlinReflection { ) } - def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { + def getJavaBeanReadableProperties(beanClass: Class[_]): Array[Method] = { val beanInfo = Introspector.getBeanInfo(beanClass) - beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") - .filterNot(_.getName == "declaringClass") - .filter(_.getReadMethod != null) + beanInfo + .getMethodDescriptors + .filter { it => it.getName.startsWith("is") || it.getName.startsWith("get") } + .filterNot { _.getName == "getClass" } + .filterNot { _.getName == "getDeclaringClass" } + .map { _.getMethod } } /* @@ -1296,7 +1302,7 @@ object KotlinReflection extends KotlinReflection { val params = method.typeSignature.paramLists.head // Check that the needed params are the same length and of matching types params.size == paramTypes.tail.size && - params.zip(paramTypes.tail).forall { case(ps, pc) => + params.zip(paramTypes.tail).forall { case (ps, pc) => ps.typeSignature.typeSymbol == mirror.classSymbol(pc) } }.map { applyMethodSymbol => diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt index b37674c8..f6d5e87b 100644 --- a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt +++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt @@ -250,6 +250,7 @@ fun schema(type: KType, map: Map = mapOf()): DataType { } klass.isData -> { + val structType = StructType( klass .primaryConstructor!! @@ -257,13 +258,16 @@ fun schema(type: KType, map: Map = mapOf()): DataType { .filter { it.findAnnotation() == null } .map { val projectedType = types[it.type.toString()] ?: it.type + + val readMethodName = when { + it.name!!.startsWith("is") -> it.name!! + else -> "get${it.name!!.replaceFirstChar { it.uppercase() }}" + } + val propertyDescriptor = PropertyDescriptor( /* propertyName = */ it.name, /* beanClass = */ klass.java, - /* readMethodName = */ "is" + it.name?.replaceFirstChar { - if (it.isLowerCase()) it.titlecase(Locale.getDefault()) - else it.toString() - }, + /* readMethodName = */ readMethodName, /* writeMethodName = */ null ) diff --git a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt index 1b7dec29..5d6affcb 100644 --- a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt +++ b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt @@ -107,6 +107,15 @@ class EncodingTest : ShouldSpec({ val dataset = ints.toDS() dataset.collectAsList() shouldBe ints } + + should("handle data classes with isSomething") { + val dataClasses = listOf( + IsSomethingClass(true, false, true, 1.0, 2.0, 0.0), + IsSomethingClass(false, true, true, 1.0, 2.0, 0.0), + ) + val dataset = dataClasses.toDS().showDS() + dataset.collectAsList() shouldBe dataClasses + } } } context("known dataTypes") { @@ -174,6 +183,16 @@ class EncodingTest : ShouldSpec({ asList.first() shouldBe t("a", t("a", 1, LonLat(1.0, 1.0))) } + should("Be able to serialize Scala Tuples including isSomething data classes") { + val dataset = dsOf( + t("a", t("a", 1, IsSomethingClass(true, false, true, 1.0, 2.0, 0.0))), + t("b", t("b", 2, IsSomethingClass(false, true, true, 1.0, 2.0, 0.0))), + ) + dataset.show() + val asList = dataset.takeAsList(2) + asList.first() shouldBe t("a", t("a", 1, IsSomethingClass(true, false, true, 1.0, 2.0, 0.0))) + } + should("Be able to serialize data classes with tuples") { val dataset = dsOf( DataClassWithTuple(t(5L, "test", t(""))), @@ -495,6 +514,15 @@ class EncodingTest : ShouldSpec({ } }) +data class IsSomethingClass( + val enabled: Boolean, + val isEnabled: Boolean, + val getEnabled: Boolean, + val double: Double, + val isDouble: Double, + val getDouble: Double +) + data class DataClassWithTuple(val tuple: T) data class LonLat(val lon: Double, val lat: Double)