feat: adds support for Tuple encoding
Up to now, the Kotlin API for Apache Spark had no way to work with `Tuple`s, which prevented us from

1. Mixing Scala and Kotlin code in one project
2. Calling operations like `select` that return typed tuples

Forcing users into explicit Tuple → data class conversions could also introduce an unavoidable performance hit. The cost should be negligible, but we can't really measure it, so we should give users a choice between the idiomatic Kotlin way and potentially more performant code.
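As a rough illustration of what this enables (a minimal sketch: the data is made up, while `withSpark` and `dsOf` are the existing Kotlin API entry points):

```kotlin
import org.jetbrains.kotlinx.spark.api.*
import scala.Tuple2

// Sketch: with Tuple encoding in place, a Dataset can hold scala.Tuple2
// values directly, without an intermediate data class.
fun main() = withSpark {
    val ds = dsOf(Tuple2("a", 1), Tuple2("b", 2))
    ds.map { Tuple2(it._1(), it._2() * 2) } // tuple in, tuple out
        .show()
}
```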

We thank @Jolanrensen for their commitment to the project and the huge effort that went into fixing this issue. Thank you very much!
Jolanrensen authored and asm0dey committed May 7, 2021
1 parent b18f889 commit aa11744
Showing 9 changed files with 312 additions and 47 deletions.
Expand Up @@ -22,7 +22,6 @@ import java.lang.reflect.Type
import java.lang.{Iterable => JIterable}
import java.time.LocalDate
import java.util.{Iterator => JIterator, List => JList, Map => JMap}

import com.google.common.reflect.TypeToken
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
@@ -399,10 +398,38 @@ object KotlinReflection {
getPath,
customCollectionCls = Some(predefinedDt.get.cls))

case StructType(elementType: Array[StructField]) =>
val cls = t.cls

val arguments = elementType.map { field =>
val dataType = field.dataType.asInstanceOf[DataTypeWithClass]
val nullable = dataType.nullable
val clsName = dataType.cls.getName
val fieldName = field.asInstanceOf[KStructField].delegate.name
val newPath = addToPath(fieldName)

deserializerFor(
TypeToken.of(dataType.cls),
Some(newPath),
Some(dataType).filter(_.isInstanceOf[ComplexWrapper])
)
}
val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)


if (path.nonEmpty) {
expressions.If(
IsNull(getPath),
expressions.Literal.create(null, ObjectType(cls)),
newInstance
)
} else {
newInstance
}

case _ =>
throw new UnsupportedOperationException(
s"No Encoder found for $typeToken")
s"No Encoder found for $typeToken in deserializerFor\n" + path)
}
}
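A note on the null guard above: wrapping `NewInstance` in `If(IsNull(getPath), Literal(null), …)` means a NULL struct column deserializes to a null tuple, not a tuple of null fields. A rough Kotlin analogue of that rule (purely illustrative, not part of the commit):

```kotlin
import scala.Tuple2

// Illustrative only: mirrors the null guard built into the expression tree.
// A missing struct decodes to null; a present one is constructed field by field.
fun decodePair(row: Pair<Any?, Any?>?): Tuple2<Any?, Any?>? =
    if (row == null) null                // If(IsNull(getPath), Literal(null), ...)
    else Tuple2(row.first, row.second)   // NewInstance(cls, arguments)
```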

@@ -608,8 +635,34 @@ object KotlinReflection {
case ArrayType(elementType, _) =>
toCatalystArray(inputObject, TypeToken.of(elementType.asInstanceOf[DataTypeWithClass].cls), Some(elementType.asInstanceOf[DataTypeWithClass]))

case StructType(elementType: Array[StructField]) =>
val cls = otherTypeWrapper.cls
val names = elementType.map(_.name)

val beanInfo = Introspector.getBeanInfo(cls)
val methods = beanInfo.getMethodDescriptors.filter(it => names.contains(it.getName))

val fields = elementType.map { structField =>

val maybeProp = methods.find(it => it.getName == structField.name)
if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${structField.name} is not found among available props, which are: ${methods.map(_.getName).mkString(", ")}")
val fieldName = structField.name
val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls
val propDt = structField.dataType.asInstanceOf[DataTypeWithClass]
val fieldValue = Invoke(
inputObject,
maybeProp.get.getName,
inferExternalType(propClass),
returnNullable = propDt.nullable
)
expressions.Literal(fieldName) :: serializerFor(fieldValue, TypeToken.of(propClass), propDt match { case c: ComplexWrapper => Some(c) case _ => None }) :: Nil
}
val nonNullOutput = CreateNamedStruct(fields.flatten.seq)
val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)

case _ =>
throw new UnsupportedOperationException(s"No Encoder found for $typeToken.")
throw new UnsupportedOperationException(s"No Encoder found for $typeToken in serializerFor. $otherTypeWrapper")

}
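The serializer side leans on JavaBeans introspection: scala.Tuple2 exposes its fields to Java as methods `_1()` and `_2()`, so matching the method descriptors against the generated struct field names `_1`, `_2` finds the right accessors. A quick illustrative check from Kotlin (not part of the commit):

```kotlin
import java.beans.Introspector

// Lists the tuple accessors that bean introspection reports for Tuple2;
// _1 and _2 should be among them, matching the struct field names.
fun main() {
    val accessors = Introspector.getBeanInfo(scala.Tuple2::class.java)
        .methodDescriptors
        .map { it.name }
        .filter { it.matches(Regex("""_\d+""")) }
    println(accessors) // expected to include _1 and _2
}
```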

154 changes: 125 additions & 29 deletions core/3.0/src/main/scala/org/apache/spark/sql/KotlinReflection.scala
@@ -20,8 +20,6 @@

package org.apache.spark.sql

import java.beans.{Introspector, PropertyDescriptor}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
import org.apache.spark.sql.catalyst.SerializerBuildHelper._
@@ -33,6 +31,8 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection, WalkedTypePa
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

import java.beans.{Introspector, PropertyDescriptor}


/**
* A helper trait to create [[org.apache.spark.sql.catalyst.encoders.ExpressionEncoder]]s
@@ -440,11 +440,79 @@ object KotlinReflection extends KotlinReflection {

UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(t.cls))

case StructType(elementType: Array[StructField]) =>
val cls = t.cls

val arguments = elementType.map { field =>
val dataType = field.dataType.asInstanceOf[DataTypeWithClass]
val nullable = dataType.nullable
val clsName = getClassNameFromType(getType(dataType.cls))
val newTypePath = walkedTypePath.recordField(clsName, field.name)

// For tuples, we grab the inner fields by ordinal instead of name.
val newPath = deserializerFor(
getType(dataType.cls),
addToPath(path, field.name, dataType.dt, newTypePath),
newTypePath,
Some(dataType).filter(_.isInstanceOf[ComplexWrapper])
)
expressionWithNullSafety(
newPath,
nullable = nullable,
newTypePath
)
}
val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)

org.apache.spark.sql.catalyst.expressions.If(
IsNull(path),
org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)),
newInstance
)


case _ =>
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath)
}
}

case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)

val cls = getClassFromType(tpe)

val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
val Schema(dataType, nullable) = schemaFor(fieldType)
val clsName = getClassNameFromType(fieldType)
val newTypePath = walkedTypePath.recordField(clsName, fieldName)

// For tuples, we grab the inner fields by ordinal instead of name.
val newPath = if (cls.getName startsWith "scala.Tuple") {
deserializerFor(
fieldType,
addToPathOrdinal(path, i, dataType, newTypePath),
newTypePath)
} else {
deserializerFor(
fieldType,
addToPath(path, fieldName, dataType, newTypePath),
newTypePath)
}
expressionWithNullSafety(
newPath,
nullable = nullable,
newTypePath)
}

val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)

org.apache.spark.sql.catalyst.expressions.If(
IsNull(path),
org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)),
newInstance
)

case _ =>
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath)
@@ -519,7 +587,7 @@ object KotlinReflection extends KotlinReflection {

def toCatalystArray(input: Expression, elementType: `Type`, predefinedDt: Option[DataTypeWithClass] = None): Expression = {
predefinedDt.map(_.dt).getOrElse(dataTypeFor(elementType)) match {
case dt:StructType =>
case dt: StructType =>
val clsName = getClassNameFromType(elementType)
val newPath = walkedTypePath.recordArray(clsName)
createSerializerForMapObjects(input, ObjectType(predefinedDt.get.cls),
@@ -662,32 +730,6 @@ object KotlinReflection extends KotlinReflection {
createSerializerForUserDefinedType(inputObject, udt, udtClass)
//</editor-fold>


case t if definedByConstructorParams(t) =>
if (seenTypeSet.contains(t)) {
throw new UnsupportedOperationException(
s"cannot have circular references in class, but got the circular reference of class $t")
}

val params = getConstructorParameters(t)
val fields = params.map { case (fieldName, fieldType) =>
if (javaKeywords.contains(fieldName)) {
throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
"cannot be used as field name\n" + walkedTypePath)
}

// SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNull
// is necessary here. Because for a nullable nested inputObject with struct data
// type, e.g. StructType(IntegerType, StringType), it will return nullable=true
// for IntegerType without KnownNotNull. And that's what we do not expect to.
val fieldValue = Invoke(KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType),
returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
val clsName = getClassNameFromType(fieldType)
val newPath = walkedTypePath.recordField(clsName, fieldName)
(fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t))
}
createSerializerForObject(inputObject, fields)

case _ if predefinedDt.isDefined =>
predefinedDt.get match {
case dataType: KDataTypeWrapper =>
@@ -735,12 +777,66 @@ object KotlinReflection extends KotlinReflection {
)
case ArrayType(elementType, _) =>
toCatalystArray(inputObject, getType(elementType.asInstanceOf[DataTypeWithClass].cls), Some(elementType.asInstanceOf[DataTypeWithClass]))

case StructType(elementType: Array[StructField]) =>
val cls = otherTypeWrapper.cls
val names = elementType.map(_.name)

val beanInfo = Introspector.getBeanInfo(cls)
val methods = beanInfo.getMethodDescriptors.filter(it => names.contains(it.getName))


val fields = elementType.map { structField =>

val maybeProp = methods.find(it => it.getName == structField.name)
if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${structField.name} is not found among available props, which are: ${methods.map(_.getName).mkString(", ")}")
val fieldName = structField.name
val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls
val propDt = structField.dataType.asInstanceOf[DataTypeWithClass]
val fieldValue = Invoke(
inputObject,
maybeProp.get.getName,
inferExternalType(propClass),
returnNullable = propDt.nullable
)
val newPath = walkedTypePath.recordField(propClass.getName, fieldName)
(fieldName, serializerFor(fieldValue, getType(propClass), newPath, seenTypeSet, if (propDt.isInstanceOf[ComplexWrapper]) Some(propDt) else None))

}
createSerializerForObject(inputObject, fields)

case _ =>
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath)

}
}

case t if definedByConstructorParams(t) =>
if (seenTypeSet.contains(t)) {
throw new UnsupportedOperationException(
s"cannot have circular references in class, but got the circular reference of class $t")
}

val params = getConstructorParameters(t)
val fields = params.map { case (fieldName, fieldType) =>
if (javaKeywords.contains(fieldName)) {
throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
"cannot be used as field name\n" + walkedTypePath)
}

// SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNull
// is necessary here. Because for a nullable nested inputObject with struct data
// type, e.g. StructType(IntegerType, StringType), it will return nullable=true
// for IntegerType without KnownNotNull. And that's what we do not expect to.
val fieldValue = Invoke(KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType),
returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
val clsName = getClassNameFromType(fieldType)
val newPath = walkedTypePath.recordField(clsName, fieldName)
(fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t))
}
createSerializerForObject(inputObject, fields)

case _ =>
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath)
@@ -36,6 +36,7 @@ import org.apache.spark.sql.streaming.GroupStateTimeout
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.*
import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions
import scala.*
import scala.collection.Seq
import scala.reflect.`ClassTag$`
import java.beans.PropertyDescriptor
@@ -122,8 +123,6 @@ inline fun <reified T> List<T>.toDS(spark: SparkSession): Dataset<T> =
* It creates encoder for any given supported type T
*
* Supported types are data classes, primitives, and Lists, Maps and Arrays containing them
* are you here?
* Pavel??
* @param T type, supported by Spark
* @return generated encoder
*/
@@ -141,6 +140,7 @@ fun <T> generateEncoder(type: KType, cls: KClass<*>): Encoder<T> {
private fun isSupportedClass(cls: KClass<*>): Boolean = cls.isData
|| cls.isSubclassOf(Map::class)
|| cls.isSubclassOf(Iterable::class)
|| cls.isSubclassOf(Product::class)
|| cls.java.isArray

@Suppress("UNCHECKED_CAST")
@@ -418,6 +418,20 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
)
KDataTypeWrapper(structType, klass.java, true)
}
klass.isSubclassOf(Product::class) -> {
val params = type.arguments.mapIndexed { i, it ->
"_${i + 1}" to it.type!!
}

val structType = DataTypes.createStructType(
params.map { (fieldName, fieldType) ->
val dataType = schema(fieldType, types)
KStructField(fieldName, StructField(fieldName, dataType, fieldType.isMarkedNullable, Metadata.empty()))
}.toTypedArray()
)

KComplexTypeWrapper(structType, klass.java, true)
}
else -> throw IllegalArgumentException("$type is unsupported")
}
}
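To make the `Product` branch above concrete: for a `Tuple2<String, Int>` it should generate a struct equivalent to the following hand-built schema (a sketch assuming non-nullable type arguments):

```kotlin
import org.apache.spark.sql.types.DataTypes

// Hand-built equivalent of the struct expected for Tuple2<String, Int>:
// field names follow Scala's _1.._22 accessors; nullability comes from
// the Kotlin type arguments.
val tuple2Schema = DataTypes.createStructType(
    arrayOf(
        DataTypes.createStructField("_1", DataTypes.StringType, false),
        DataTypes.createStructField("_2", DataTypes.IntegerType, false),
    )
)
```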
@@ -430,6 +444,8 @@
ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN
}

val timestampDt = `TimestampType$`.`MODULE$`
val dateDt = `DateType$`.`MODULE$`
private val knownDataTypes = mapOf(
Byte::class to DataTypes.ByteType,
Short::class to DataTypes.ShortType,
Expand All @@ -439,10 +455,10 @@ private val knownDataTypes = mapOf(
Float::class to DataTypes.FloatType,
Double::class to DataTypes.DoubleType,
String::class to DataTypes.StringType,
LocalDate::class to `DateType$`.`MODULE$`,
Date::class to `DateType$`.`MODULE$`,
Timestamp::class to `TimestampType$`.`MODULE$`,
Instant::class to `TimestampType$`.`MODULE$`
LocalDate::class to dateDt,
Date::class to dateDt,
Timestamp::class to timestampDt,
Instant::class to timestampDt
)

private fun transitiveMerge(a: Map<String, KType>, b: Map<String, KType>): Map<String, KType> {
@@ -459,4 +475,4 @@ class Memoize1<in T, out R>(val f: (T) -> R) : (T) -> R {

private fun <T, R> ((T) -> R).memoize(): (T) -> R = Memoize1(this)

private val memoizedSchema = { x: KType -> schema(x) }.memoize()
@@ -1,4 +1,23 @@
@file:Suppress("NOTHING_TO_INLINE", "RemoveExplicitTypeArguments")
/*-
* =LICENSE=
* Kotlin Spark API: API for Spark 2.4+ (Scala 2.12)
* ----------
* Copyright (C) 2019 - 2021 JetBrains
* ----------
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =LICENSEEND=
*/
@file:Suppress("NOTHING_TO_INLINE", "RemoveExplicitTypeArguments", "unused")

package org.jetbrains.kotlinx.spark.api
