Skip to content

Commit

Permalink
fixed encoding isSomething names in data classes (#171)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jolanrensen authored Jul 13, 2022
1 parent 4aabc9c commit 9d3f364
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 37 deletions.
72 changes: 39 additions & 33 deletions core/src/main/scala/org/apache/spark/sql/KotlinReflection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.apache.spark.util.Utils

import java.beans.{Introspector, PropertyDescriptor}
import java.lang.Exception
import java.lang.reflect.Method


/**
Expand Down Expand Up @@ -212,11 +213,11 @@ object KotlinReflection extends KotlinReflection {
* @param walkedTypePath The paths from top to bottom to access current field when deserializing.
*/
private def deserializerFor(
tpe: `Type`,
path: Expression,
walkedTypePath: WalkedTypePath,
predefinedDt: Option[DataTypeWithClass] = None
): Expression = cleanUpReflectionObjects {
tpe: `Type`,
path: Expression,
walkedTypePath: WalkedTypePath,
predefinedDt: Option[DataTypeWithClass] = None
): Expression = cleanUpReflectionObjects {
baseType(tpe) match {

//<editor-fold desc="Description">
Expand Down Expand Up @@ -685,18 +686,18 @@ object KotlinReflection extends KotlinReflection {
* internal representation.
*/
private def serializerFor(
inputObject: Expression,
tpe: `Type`,
walkedTypePath: WalkedTypePath,
seenTypeSet: Set[`Type`] = Set.empty,
predefinedDt: Option[DataTypeWithClass] = None,
): Expression = cleanUpReflectionObjects {
inputObject: Expression,
tpe: `Type`,
walkedTypePath: WalkedTypePath,
seenTypeSet: Set[`Type`] = Set.empty,
predefinedDt: Option[DataTypeWithClass] = None,
): Expression = cleanUpReflectionObjects {

def toCatalystArray(
input: Expression,
elementType: `Type`,
predefinedDt: Option[DataTypeWithClass] = None,
): Expression = {
input: Expression,
elementType: `Type`,
predefinedDt: Option[DataTypeWithClass] = None,
): Expression = {
val dataType = predefinedDt
.map(_.dt)
.getOrElse {
Expand All @@ -705,7 +706,7 @@ object KotlinReflection extends KotlinReflection {

dataType match {

case dt @ (MapType(_, _, _) | ArrayType(_, _) | StructType(_)) => {
case dt@(MapType(_, _, _) | ArrayType(_, _) | StructType(_)) => {
val clsName = getClassNameFromType(elementType)
val newPath = walkedTypePath.recordArray(clsName)
createSerializerForMapObjects(
Expand All @@ -726,7 +727,7 @@ object KotlinReflection extends KotlinReflection {
// case dt: ByteType =>
// createSerializerForPrimitiveArray(input, dt)

case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => {
case dt@(BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType) => {
val cls = input.dataType.asInstanceOf[ObjectType].cls
if (cls.isArray && cls.getComponentType.isPrimitive) {
createSerializerForPrimitiveArray(input, dt)
Expand Down Expand Up @@ -945,11 +946,11 @@ object KotlinReflection extends KotlinReflection {
// Kotlin specific cases
case t if predefinedDt.isDefined => {

// if (seenTypeSet.contains(t)) {
// throw new UnsupportedOperationException(
// s"cannot have circular references in class, but got the circular reference of class $t"
// )
// }
// if (seenTypeSet.contains(t)) {
// throw new UnsupportedOperationException(
// s"cannot have circular references in class, but got the circular reference of class $t"
// )
// }

predefinedDt.get match {

Expand All @@ -959,18 +960,20 @@ object KotlinReflection extends KotlinReflection {
val properties = getJavaBeanReadableProperties(cls)
val structFields = dataType.dt.fields.map(_.asInstanceOf[KStructField])
val fields: Array[(String, Expression)] = structFields.map { structField =>
val maybeProp = properties.find(it => it.getReadMethod.getName == structField.getterName)
if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${
structField.name
} is not found among available props, which are: ${properties.map(_.getName).mkString(", ")}"
)
val maybeProp = properties.find {
_.getName == structField.getterName
}
if (maybeProp.isEmpty)
throw new IllegalArgumentException(
s"Field ${structField.name} is not found among available props, which are: ${properties.map(_.getName).mkString(", ")}"
)
val fieldName = structField.name
val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls
val propDt = structField.dataType.asInstanceOf[DataTypeWithClass]

val fieldValue = Invoke(
inputObject,
maybeProp.get.getReadMethod.getName,
maybeProp.get.getName,
inferExternalType(propClass),
returnNullable = structField.nullable
)
Expand Down Expand Up @@ -1124,11 +1127,14 @@ object KotlinReflection extends KotlinReflection {
)
}

def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
def getJavaBeanReadableProperties(beanClass: Class[_]): Array[Method] = {
val beanInfo = Introspector.getBeanInfo(beanClass)
beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
.filterNot(_.getName == "declaringClass")
.filter(_.getReadMethod != null)
beanInfo
.getMethodDescriptors
.filter { it => it.getName.startsWith("is") || it.getName.startsWith("get") }
.filterNot { _.getName == "getClass" }
.filterNot { _.getName == "getDeclaringClass" }
.map { _.getMethod }
}

/*
Expand Down Expand Up @@ -1296,7 +1302,7 @@ object KotlinReflection extends KotlinReflection {
val params = method.typeSignature.paramLists.head
// Check that the needed params are the same length and of matching types
params.size == paramTypes.tail.size &&
params.zip(paramTypes.tail).forall { case(ps, pc) =>
params.zip(paramTypes.tail).forall { case (ps, pc) =>
ps.typeSignature.typeSymbol == mirror.classSymbol(pc)
}
}.map { applyMethodSymbol =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,20 +250,24 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
}

klass.isData -> {

val structType = StructType(
klass
.primaryConstructor!!
.parameters
.filter { it.findAnnotation<Transient>() == null }
.map {
val projectedType = types[it.type.toString()] ?: it.type

val readMethodName = when {
it.name!!.startsWith("is") -> it.name!!
else -> "get${it.name!!.replaceFirstChar { it.uppercase() }}"
}

val propertyDescriptor = PropertyDescriptor(
/* propertyName = */ it.name,
/* beanClass = */ klass.java,
/* readMethodName = */ "is" + it.name?.replaceFirstChar {
if (it.isLowerCase()) it.titlecase(Locale.getDefault())
else it.toString()
},
/* readMethodName = */ readMethodName,
/* writeMethodName = */ null
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,15 @@ class EncodingTest : ShouldSpec({
val dataset = ints.toDS()
dataset.collectAsList() shouldBe ints
}

should("handle data classes with isSomething") {
val dataClasses = listOf(
IsSomethingClass(true, false, true, 1.0, 2.0, 0.0),
IsSomethingClass(false, true, true, 1.0, 2.0, 0.0),
)
val dataset = dataClasses.toDS().showDS()
dataset.collectAsList() shouldBe dataClasses
}
}
}
context("known dataTypes") {
Expand Down Expand Up @@ -174,6 +183,16 @@ class EncodingTest : ShouldSpec({
asList.first() shouldBe t("a", t("a", 1, LonLat(1.0, 1.0)))
}

should("Be able to serialize Scala Tuples including isSomething data classes") {
val dataset = dsOf(
t("a", t("a", 1, IsSomethingClass(true, false, true, 1.0, 2.0, 0.0))),
t("b", t("b", 2, IsSomethingClass(false, true, true, 1.0, 2.0, 0.0))),
)
dataset.show()
val asList = dataset.takeAsList(2)
asList.first() shouldBe t("a", t("a", 1, IsSomethingClass(true, false, true, 1.0, 2.0, 0.0)))
}

should("Be able to serialize data classes with tuples") {
val dataset = dsOf(
DataClassWithTuple(t(5L, "test", t(""))),
Expand Down Expand Up @@ -495,6 +514,15 @@ class EncodingTest : ShouldSpec({
}
})

data class IsSomethingClass(
val enabled: Boolean,
val isEnabled: Boolean,
val getEnabled: Boolean,
val double: Double,
val isDouble: Double,
val getDouble: Double
)

data class DataClassWithTuple<T : Product>(val tuple: T)

data class LonLat(val lon: Double, val lat: Double)
Expand Down

0 comments on commit 9d3f364

Please sign in to comment.