Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for generic type to Any? erasure in new column creation #192

Merged
merged 1 commit into from
Nov 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,38 @@ internal fun KType.projectUpTo(superClass: KClass<*>): KType {
return current.withNullability(isMarkedNullable)
}

internal fun KType.replaceTypeParameters(): KType {
var replaced = false
val arguments = arguments.map {
val type = it.type
val newType = when {
type == null -> typeOf<Any?>()
type.classifier is KTypeParameter -> {
replaced = true
(type.classifier as KTypeParameter).upperBounds.firstOrNull() ?: typeOf<Any?>()
}
/**
* Changes generic type parameters to `Any?`, like `List<T> -> List<Any?>`.
* Works recursively as well.
*/
@PublishedApi
internal fun KType.eraseGenericTypeParameters(): KType {
fun KType.eraseRecursively(): Pair<Boolean, KType> {
var replaced = false
val arguments = arguments.map {
val type = it.type
val (replacedDownwards, newType) = when {
type == null -> typeOf<Any?>()

type.classifier is KTypeParameter -> {
replaced = true
(type.classifier as KTypeParameter).upperBounds.firstOrNull() ?: typeOf<Any?>()
}

else -> type
else -> type
}.eraseRecursively()

if (replacedDownwards) replaced = true

KTypeProjection.invariant(newType)
}
KTypeProjection.invariant(newType)
return Pair(
first = replaced,
second = if (replaced) jvmErasure.createType(arguments, isMarkedNullable) else this,
)
}
return if (replaced) jvmErasure.createType(arguments, isMarkedNullable)
else this

return eraseRecursively().second
}

internal fun inheritanceChain(subClass: KClass<*>, superClass: KClass<*>): List<Pair<KClass<*>, KType>> {
Expand Down Expand Up @@ -255,7 +270,7 @@ internal fun Iterable<KType>.commonTypeListifyValues(): KType {

else -> {
val kclass = commonParent(distinct.map { it.jvmErasure }) ?: return typeOf<Any>()
val projections = distinct.map { it.projectUpTo(kclass).replaceTypeParameters() }
val projections = distinct.map { it.projectUpTo(kclass).eraseGenericTypeParameters() }
require(projections.all { it.jvmErasure == kclass })
val arguments = List(kclass.typeParameters.size) { i ->
val projectionTypes = projections
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ internal fun Iterable<KType?>.commonType(): KType {
distinct.size == 1 -> distinct.single()!!
else -> {
val kclass = commonParent(distinct.map { it!!.jvmErasure }) ?: return typeOf<Any>()
val projections = distinct.map { it!!.projectUpTo(kclass).replaceTypeParameters() }
val projections = distinct.map { it!!.projectUpTo(kclass).eraseGenericTypeParameters() }
require(projections.all { it.jvmErasure == kclass })
val arguments = List(kclass.typeParameters.size) { i ->
val projectionTypes = projections
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnWithPath
import org.jetbrains.kotlinx.dataframe.impl.DataFrameReceiver
import org.jetbrains.kotlinx.dataframe.impl.DataRowImpl
import org.jetbrains.kotlinx.dataframe.impl.asList
import org.jetbrains.kotlinx.dataframe.impl.eraseGenericTypeParameters
import org.jetbrains.kotlinx.dataframe.impl.guessValueType
import org.jetbrains.kotlinx.dataframe.index
import org.jetbrains.kotlinx.dataframe.nrow
Expand All @@ -58,9 +59,25 @@ internal fun <T, R> ColumnsContainer<T>.newColumn(
): DataColumn<R> {
val (nullable, values) = computeValues(this as DataFrame<T>, expression)
return when (infer) {
Infer.Nulls -> DataColumn.create(name, values, type.withNullability(nullable), Infer.None)
Infer.Type -> DataColumn.createWithTypeInference(name, values, nullable)
Infer.None -> DataColumn.create(name, values, type, Infer.None)
Infer.Nulls -> DataColumn.create(
name = name,
values = values,
type = type.withNullability(nullable).eraseGenericTypeParameters(),
infer = Infer.None,
)

Infer.Type -> DataColumn.createWithTypeInference(
name = name,
values = values,
nullable = nullable,
)

Infer.None -> DataColumn.create(
name = name,
values = values,
type = type.eraseGenericTypeParameters(),
infer = Infer.None,
)
}
}

Expand Down
10 changes: 10 additions & 0 deletions core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/add.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package org.jetbrains.kotlinx.dataframe.api

import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.junit.Test
import kotlin.reflect.typeOf

class AddTests {

Expand All @@ -23,4 +25,12 @@ class AddTests {
df.add("y") { next()?.newValue() ?: 1 }
}
}

private fun <T> AnyFrame.addValue(value: T) = add("value") { listOf(value) }

@Test
fun `add with generic function`() {
val df = dataFrameOf("a")(1).addValue(2)
df["value"].type() shouldBe typeOf<List<Any?>>()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,31 @@ import io.kotest.assertions.throwables.shouldNotThrowAny
import io.kotest.matchers.shouldBe
import io.kotest.matchers.types.shouldBeInstanceOf
import org.intellij.lang.annotations.Language
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
import org.jetbrains.kotlinx.dataframe.type
import org.jetbrains.kotlinx.jupyter.api.MimeTypedResult
import org.jetbrains.kotlinx.jupyter.testkit.JupyterReplTestCase
import org.junit.Test
import kotlin.reflect.typeOf

class JupyterCodegenTests : JupyterReplTestCase() {

@Test
fun `codegen adding column with generic type function`() {
@Language("kts")
val res1 = exec(
"""
fun <T> AnyFrame.addValue(value: T) = add("value") { listOf(value) }
val df = dataFrameOf("a")(1).addValue(2)
""".trimIndent()
)
res1 shouldBe Unit
val res2 = execRaw("df") as AnyFrame

res2["value"].type shouldBe typeOf<List<Any?>>()
}

@Test
fun `codegen for enumerated frames`() {
@Language("kts")
Expand Down Expand Up @@ -78,6 +97,7 @@ class JupyterCodegenTests : JupyterReplTestCase() {
@Test
fun `codegen for chars that is forbidden in JVM identifiers`() {
val forbiddenChar = ";"

@Language("kts")
val res1 = exec(
"""
Expand All @@ -96,6 +116,7 @@ class JupyterCodegenTests : JupyterReplTestCase() {
@Test
fun `codegen for chars that is forbidden in JVM identifiers 1`() {
val forbiddenChar = "\\\\"

@Language("kts")
val res1 = exec(
"""
Expand Down