Skip to content

Commit

Permalink
[Compiler plugin] Use resolved type argument T of Iterable<T>.toDataF…
Browse files Browse the repository at this point in the history
…rame() _call_ instead of one from the return type of receiver *iterable*.toDataFrame()
  • Loading branch information
koperagen committed Nov 7, 2024
1 parent 9839b5c commit be0be0a
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.jetbrains.kotlin.fir.types.ConeKotlinType
import org.jetbrains.kotlin.fir.types.ConeNullability
import org.jetbrains.kotlin.fir.types.ConeStarProjection
import org.jetbrains.kotlin.fir.types.ConeTypeParameterType
import org.jetbrains.kotlin.fir.types.ConeTypeProjection
import org.jetbrains.kotlin.fir.types.canBeNull
import org.jetbrains.kotlin.fir.types.classId
import org.jetbrains.kotlin.fir.types.coneType
Expand All @@ -52,7 +53,6 @@ import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.name.StandardClassIds
import org.jetbrains.kotlin.name.StandardClassIds.List
import org.jetbrains.kotlin.types.checker.SimpleClassicTypeSystemContext.withNullability
import org.jetbrains.kotlinx.dataframe.codeGen.*
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
import org.jetbrains.kotlinx.dataframe.plugin.extensions.wrap
Expand All @@ -77,27 +77,32 @@ import java.util.*
class ToDataFrameDsl : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
val Arguments.body by dsl()
val Arguments.typeArg0: ConeTypeProjection? by arg(lens = Interpreter.Id)

override fun Arguments.interpret(): PluginDataFrameSchema {
val dsl = CreateDataFrameDslImplApproximation()
body(dsl, mapOf("explicitReceiver" to Interpreter.Success(receiver)))
receiver
body(dsl, mapOf("typeArg0" to Interpreter.Success(typeArg0)))
return PluginDataFrameSchema(dsl.columns)
}
}

class ToDataFrame : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
val Arguments.maxDepth: Number by arg(defaultValue = Present(DEFAULT_MAX_DEPTH))
val Arguments.typeArg0: ConeTypeProjection by arg(lens = Interpreter.Id)

override fun Arguments.interpret(): PluginDataFrameSchema {
return toDataFrame(maxDepth.toInt(), receiver, TraverseConfiguration())
return toDataFrame(maxDepth.toInt(), typeArg0, TraverseConfiguration())
}
}

class ToDataFrameDefault : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
val Arguments.typeArg0: ConeTypeProjection by arg(lens = Interpreter.Id)

override fun Arguments.interpret(): PluginDataFrameSchema {
return toDataFrame(DEFAULT_MAX_DEPTH, receiver, TraverseConfiguration())
return toDataFrame(DEFAULT_MAX_DEPTH, typeArg0, TraverseConfiguration())
}
}

Expand All @@ -115,14 +120,14 @@ private const val DEFAULT_MAX_DEPTH = 0

class Properties0 : AbstractInterpreter<Unit>() {
val Arguments.dsl: CreateDataFrameDslImplApproximation by arg()
val Arguments.explicitReceiver: FirExpression? by arg()
val Arguments.maxDepth: Int by arg()
val Arguments.body by dsl()
val Arguments.typeArg0: ConeTypeProjection by arg(lens = Interpreter.Id)

override fun Arguments.interpret() {
dsl.configuration.maxDepth = maxDepth
body(dsl.configuration.traverseConfiguration, emptyMap())
val schema = toDataFrame(dsl.configuration.maxDepth, explicitReceiver, dsl.configuration.traverseConfiguration)
val schema = toDataFrame(dsl.configuration.maxDepth, typeArg0, dsl.configuration.traverseConfiguration)
dsl.columns.addAll(schema.columns())
}
}
Expand Down Expand Up @@ -178,8 +183,8 @@ class Exclude1 : AbstractInterpreter<Unit>() {
@OptIn(SymbolInternals::class)
internal fun KotlinTypeFacade.toDataFrame(
maxDepth: Int,
explicitReceiver: FirExpression?,
traverseConfiguration: TraverseConfiguration
arg: ConeTypeProjection,
traverseConfiguration: TraverseConfiguration,
): PluginDataFrameSchema {
fun ConeKotlinType.isValueType() =
this.isArrayTypeOrNullableArrayType ||
Expand Down Expand Up @@ -290,8 +295,6 @@ internal fun KotlinTypeFacade.toDataFrame(
}
}

val receiver = explicitReceiver ?: return PluginDataFrameSchema.EMPTY
val arg = receiver.resolvedType.typeArguments.firstOrNull() ?: return PluginDataFrameSchema.EMPTY
return when {
arg.isStarProjection -> PluginDataFrameSchema.EMPTY
else -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,29 +90,43 @@ fun <T> KotlinTypeFacade.interpret(
val refinedArguments: RefinedArguments = functionCall.collectArgumentExpressions()

val defaultArguments = processor.expectedArguments.filter { it.defaultValue is Present }.map { it.name }.toSet()
val actualArgsMap = refinedArguments.associateBy { it.name.identifier }.toSortedMap()
val conflictingKeys = additionalArguments.keys intersect actualArgsMap.keys
val actualValueArguments = refinedArguments.associateBy { it.name.identifier }.toSortedMap()
val conflictingKeys = additionalArguments.keys intersect actualValueArguments.keys
if (conflictingKeys.isNotEmpty()) {
if (isTest) {
interpretationFrameworkError("Conflicting keys: $conflictingKeys")
}
return null
}
val expectedArgsMap = processor.expectedArguments
.filterNot { it.name.startsWith("typeArg") }
.associateBy { it.name }.toSortedMap().minus(additionalArguments.keys)

val unexpectedArguments = expectedArgsMap.keys - defaultArguments != actualArgsMap.keys - defaultArguments
val typeArguments = buildMap {
functionCall.typeArguments.forEachIndexed { index, firTypeProjection ->
val key = "typeArg$index"
val lens = expectedArgsMap[key]?.lens ?: return@forEachIndexed
val value: Any = if (lens == Interpreter.Id) {
firTypeProjection.toConeTypeProjection()
} else {
val type = firTypeProjection.toConeTypeProjection().type ?: session.builtinTypes.nullableAnyType.type
if (type is ConeIntersectionType) return@forEachIndexed
Marker(type)
}
put(key, Interpreter.Success(value))
}
}

val unexpectedArguments = (expectedArgsMap.keys - defaultArguments) != (actualValueArguments.keys + typeArguments.keys - defaultArguments)
if (unexpectedArguments) {
if (isTest) {
val message = buildString {
appendLine("ERROR: Different set of arguments")
appendLine("Implementation class: $processor")
appendLine("Not found in actual: ${expectedArgsMap.keys - actualArgsMap.keys}")
val diff = actualArgsMap.keys - expectedArgsMap.keys
appendLine("Not found in actual: ${expectedArgsMap.keys - actualValueArguments.keys}")
val diff = actualValueArguments.keys - expectedArgsMap.keys
appendLine("Passed, but not expected: ${diff}")
appendLine("add arguments to an interpeter:")
appendLine(diff.map { actualArgsMap[it] })
appendLine(diff.map { actualValueArguments[it] })
}
interpretationFrameworkError(message)
}
Expand All @@ -121,6 +135,7 @@ fun <T> KotlinTypeFacade.interpret(

val arguments = mutableMapOf<String, Interpreter.Success<Any?>>()
arguments += additionalArguments
arguments += typeArguments
val interpretationResults = refinedArguments.refinedArguments.mapNotNull {
val name = it.name.identifier
val expectedArgument = expectedArgsMap[name] ?: error("$processor $name")
Expand Down Expand Up @@ -269,17 +284,6 @@ fun <T> KotlinTypeFacade.interpret(
value?.let { value1 -> it.name.identifier to value1 }
}

functionCall.typeArguments.forEachIndexed { index, firTypeProjection ->
val type = firTypeProjection.toConeTypeProjection().type ?: session.builtinTypes.nullableAnyType.type
if (type is ConeIntersectionType) return@forEachIndexed
// val approximation = TypeApproximationImpl(
// type.classId!!.asFqNameString(),
// type.isMarkedNullable
// )
val approximation = Marker(type)
arguments["typeArg$index"] = Interpreter.Success(approximation)
}

return if (interpretationResults.size == refinedArguments.refinedArguments.size) {
arguments.putAll(interpretationResults)
when (val res = processor.interpret(arguments, this)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

@DataSchema
data class D(
val s: String
)

class Subtree(
val p: Int,
val l: List<Int>,
val ld: List<D>,
)

class Root(val a: Subtree)

class MyList(val l: List<Root?>): List<Root?> by l

fun box(): String {
val l = listOf(
Root(Subtree(123, listOf(1), listOf(D("ff")))),
null
)
val df = MyList(l).toDataFrame(maxDepth = 2)
df.compareSchemas(strict = true)
return "OK"
}
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,12 @@ public void testToDataFrame_column() {
runTest("testData/box/toDataFrame_column.kt");
}

@Test
@TestMetadata("toDataFrame_customIterable.kt")
public void testToDataFrame_customIterable() {
runTest("testData/box/toDataFrame_customIterable.kt");
}

@Test
@TestMetadata("toDataFrame_dataSchema.kt")
public void testToDataFrame_dataSchema() {
Expand Down

0 comments on commit be0be0a

Please sign in to comment.