Skip to content

Commit

Permalink
Merge pull request #401 from Kotlin/jupyter-any-detection
Browse files Browse the repository at this point in the history
Fix: Jupyter compile-time DF type not recognized
  • Loading branch information
Jolanrensen authored Jun 21, 2023
2 parents f1b245e + 7c3f204 commit ba4d321
Show file tree
Hide file tree
Showing 8 changed files with 350 additions and 143 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1786,12 +1786,10 @@ public interface ColumnsSelectionDsl<out T> : ColumnSelectionDsl<T>, SingleColum
*
* `df.`[select][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.select]` { `[colsOf][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.colsOf]`<`[String][String]`>().`[cols][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]`() }`
*
* @see [all]
*
*
* @param [predicate] A [ColumnFilter function][org.jetbrains.kotlinx.dataframe.ColumnFilter] that takes a [ColumnReference][org.jetbrains.kotlinx.dataframe.columns.ColumnReference] and returns a [Boolean].
* @return A [ColumnSet][org.jetbrains.kotlinx.dataframe.columns.ColumnSet] containing the columns that match the given [predicate].
*/
* @see [all] */
@Suppress("UNCHECKED_CAST")
public fun <C> ColumnSet<C>.cols(
predicate: ColumnFilter<C> = { true },
Expand Down Expand Up @@ -1829,12 +1827,10 @@ public interface ColumnsSelectionDsl<out T> : ColumnSelectionDsl<T>, SingleColum
*
* `df.`[select][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.select]` { `[colsOf][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.colsOf]`<`[String][String]`>().`[cols][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]`() }`
*
* @see [all]
*
*
* @param [predicate] A [ColumnFilter function][org.jetbrains.kotlinx.dataframe.ColumnFilter] that takes a [ColumnReference][org.jetbrains.kotlinx.dataframe.columns.ColumnReference] and returns a [Boolean].
* @return A [ColumnSet][org.jetbrains.kotlinx.dataframe.columns.ColumnSet] containing the columns that match the given [predicate].
*/
* @see [all] */
public operator fun <C> ColumnSet<C>.get(
    predicate: ColumnFilter<C> = { true },
): TransformableColumnSet<C> {
    // Bracket access is a plain alias: hand the predicate straight to `cols`.
    return cols(predicate)
}
Expand Down Expand Up @@ -1928,12 +1924,10 @@ public interface ColumnsSelectionDsl<out T> : ColumnSelectionDsl<T>, SingleColum
*
* `df.`[select][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.select]` { myColumnGroup`[`[`][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]`{ ... }`[`]`][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]` }`
*
* @see [all]
*
*
* @param [predicate] A [ColumnFilter function][org.jetbrains.kotlinx.dataframe.ColumnFilter] that takes a [ColumnReference][org.jetbrains.kotlinx.dataframe.columns.ColumnReference] and returns a [Boolean].
* @return A [ColumnSet][org.jetbrains.kotlinx.dataframe.columns.ColumnSet] containing the columns that match the given [predicate].
*/
* @see [all] */
public fun SingleColumn<*>.cols(
    predicate: ColumnFilter<*> = { true },
): TransformableColumnSet<*> {
    // All the work happens in `colsInternal`; this overload only forwards the predicate.
    return colsInternal(predicate)
}
Expand Down Expand Up @@ -1979,12 +1973,11 @@ public interface ColumnsSelectionDsl<out T> : ColumnSelectionDsl<T>, SingleColum
*
* `df.`[select][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.select]` { myColumnGroup`[`[`][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]`{ ... }`[`]`][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]` }`
*
* @see [all]
*
*
*
* @param [predicate] A [ColumnFilter function][org.jetbrains.kotlinx.dataframe.ColumnFilter] that takes a [ColumnReference][org.jetbrains.kotlinx.dataframe.columns.ColumnReference] and returns a [Boolean].
* @return A [ColumnSet][org.jetbrains.kotlinx.dataframe.columns.ColumnSet] containing the columns that match the given [predicate].
* @see [all]
*
*/
public operator fun SingleColumn<*>.get(
predicate: ColumnFilter<*> = { true },
Expand Down Expand Up @@ -2172,12 +2165,10 @@ public interface ColumnsSelectionDsl<out T> : ColumnSelectionDsl<T>, SingleColum
*
* `df.`[select][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.select]` { Type::columnGroup.`[cols][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]`() }`
*
* @see [all]
*
*
* @param [predicate] A [ColumnFilter function][org.jetbrains.kotlinx.dataframe.ColumnFilter] that takes a [ColumnReference][org.jetbrains.kotlinx.dataframe.columns.ColumnReference] and returns a [Boolean].
* @return A [ColumnSet][org.jetbrains.kotlinx.dataframe.columns.ColumnSet] containing the columns that match the given [predicate].
*/
* @see [all] */
public fun KProperty<*>.cols(
    predicate: ColumnFilter<*> = { true },
): TransformableColumnSet<*> {
    // Turn the property into a column-group accessor first, then filter its columns.
    val group = colGroup(this)
    return group.cols(predicate)
}
Expand Down Expand Up @@ -2212,12 +2203,10 @@ public interface ColumnsSelectionDsl<out T> : ColumnSelectionDsl<T>, SingleColum
*
* `df.`[select][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.select]` { Type::columnGroup.`[cols][org.jetbrains.kotlinx.dataframe.api.ColumnsSelectionDsl.cols]`() }`
*
* @see [all]
*
*
* @param [predicate] A [ColumnFilter function][org.jetbrains.kotlinx.dataframe.ColumnFilter] that takes a [ColumnReference][org.jetbrains.kotlinx.dataframe.columns.ColumnReference] and returns a [Boolean].
* @return A [ColumnSet][org.jetbrains.kotlinx.dataframe.columns.ColumnSet] containing the columns that match the given [predicate].
*/
* @see [all] */
public operator fun KProperty<*>.get(
    predicate: ColumnFilter<*> = { true },
): TransformableColumnSet<Any?> {
    // Bracket access is a plain alias: hand the predicate straight to `cols`.
    return cols(predicate)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,33 @@ import org.jetbrains.dataframe.impl.codeGen.ReplCodeGenerator
import org.jetbrains.kotlinx.dataframe.AnyCol
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.AnyRow
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.api.Convert
import org.jetbrains.kotlinx.dataframe.api.FormattedFrame
import org.jetbrains.kotlinx.dataframe.api.Gather
import org.jetbrains.kotlinx.dataframe.api.GroupBy
import org.jetbrains.kotlinx.dataframe.api.Merge
import org.jetbrains.kotlinx.dataframe.api.Pivot
import org.jetbrains.kotlinx.dataframe.api.PivotGroupBy
import org.jetbrains.kotlinx.dataframe.api.ReducedGroupBy
import org.jetbrains.kotlinx.dataframe.api.ReducedPivot
import org.jetbrains.kotlinx.dataframe.api.ReducedPivotGroupBy
import org.jetbrains.kotlinx.dataframe.api.Split
import org.jetbrains.kotlinx.dataframe.api.SplitWithTransform
import org.jetbrains.kotlinx.dataframe.api.Update
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
import org.jetbrains.kotlinx.dataframe.api.asDataFrame
import org.jetbrains.kotlinx.dataframe.api.columnsCount
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
import org.jetbrains.kotlinx.dataframe.api.frames
import org.jetbrains.kotlinx.dataframe.api.into
import org.jetbrains.kotlinx.dataframe.api.isColumnGroup
import org.jetbrains.kotlinx.dataframe.api.name
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.api.values
import org.jetbrains.kotlinx.dataframe.codeGen.CodeWithConverter
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
Expand All @@ -31,6 +55,7 @@ import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterIntegration
import org.jetbrains.kotlinx.jupyter.api.libraries.resources
import kotlin.reflect.KClass
import kotlin.reflect.KProperty
import kotlin.reflect.KType
import kotlin.reflect.full.isSubtypeOf

/** Users will get an error if their Kotlin Jupyter kernel is older than this version. */
Expand All @@ -45,6 +70,101 @@ internal class Integration(

val version = options["v"]

/**
 * Builds the code for [argument] via [codeWithConverter] and runs it in the kernel.
 *
 * @param codeWithConverter generated code plus an optional converter.
 * @param argument expression text passed to [CodeWithConverter.with].
 * @return the name of the execution result when [codeWithConverter] has a converter;
 * `null` when the built code is blank or there is no converter.
 */
private fun KotlinKernelHost.execute(codeWithConverter: CodeWithConverter, argument: String): VariableName? {
    val code = codeWithConverter.with(argument)
    // Nothing to run: skip the kernel round-trip entirely.
    if (code.isBlank()) return null

    val result = execute(code)
    return if (codeWithConverter.hasConverter) result.name else null
}

/**
 * Runs [codeWithConverter] against [property], referring to the property through an
 * explicit cast to [type].
 *
 * The expression handed on has the form `(<name> as <type>)`, with `!!` appended to the
 * name when the property's return type is marked nullable.
 *
 * @return whatever the two-argument `execute` overload returns for that expression.
 */
private fun KotlinKernelHost.execute(
    codeWithConverter: CodeWithConverter,
    property: KProperty<*>,
    type: KType,
): VariableName? {
    val nonNullAssertion = if (property.returnType.isMarkedNullable) "!!" else ""
    val castExpression = "(${property.name}$nonNullAssertion as $type)"
    return execute(codeWithConverter, castExpression)
}

/**
 * Generates a data schema from the URL held by [importDataSchema] and executes it in the kernel.
 *
 * The schema is named after [property] with a `DataSchema` suffix. A DISPLAY message is
 * executed in both the success and the failure case.
 *
 * @return the schema name on success, or `null` when reading the schema failed.
 */
private fun KotlinKernelHost.updateImportDataSchemaVariable(
    importDataSchema: ImportDataSchema,
    property: KProperty<*>,
): VariableName? {
    val codeGenFormats = supportedFormats.filterIsInstance<SupportedCodeGenerationFormat>()
    val schemaName = property.name + "DataSchema"
    val codeGenResult = CodeGenerator.urlCodeGenReader(importDataSchema.url, schemaName, codeGenFormats, true)

    return when (codeGenResult) {
        is CodeGenerationReadResult.Error -> {
            execute("""DISPLAY("Failed to read data schema from ${importDataSchema.url}: ${codeGenResult.reason}")""")
            null
        }

        is CodeGenerationReadResult.Success -> {
            val readDfMethod = codeGenResult.getReadDfMethod(importDataSchema.url.toExternalForm())
            // Additional imports go first, one per line, followed by the generated code.
            val importLines = readDfMethod.additionalImports.joinToString("\n")

            execute("$importLines\n${codeGenResult.code}")
            execute("""DISPLAY("Data schema successfully imported as ${property.name}: $schemaName")""")

            schemaName
        }
    }
}

/**
 * Processes [df] with [codeGen] for [property] and executes the result, casting the
 * property to a star-projected, non-nullable [DataFrame] type.
 *
 * @return the value returned by the three-argument `execute` overload.
 */
private fun KotlinKernelHost.updateAnyFrameVariable(
    df: AnyFrame,
    property: KProperty<*>,
    codeGen: ReplCodeGenerator,
): VariableName? {
    val generated = codeGen.process(df, property)
    val frameType = DataFrame::class.createStarProjectedType(false)
    return execute(generated, property, frameType)
}

/**
 * Processes [row] with [codeGen] for [property] and executes the result, casting the
 * property to a star-projected, non-nullable [DataRow] type.
 *
 * @return the value returned by the three-argument `execute` overload.
 */
private fun KotlinKernelHost.updateAnyRowVariable(
    row: AnyRow,
    property: KProperty<*>,
    codeGen: ReplCodeGenerator,
): VariableName? {
    val generated = codeGen.process(row, property)
    val rowType = DataRow::class.createStarProjectedType(false)
    return execute(generated, property, rowType)
}

/**
 * Processes [col] (viewed as a data frame) with [codeGen] for [property] and executes the
 * result, casting the property to a star-projected, non-nullable [ColumnGroup] type.
 *
 * @return the value returned by the three-argument `execute` overload.
 */
private fun KotlinKernelHost.updateColumnGroupVariable(
    col: ColumnGroup<*>,
    property: KProperty<*>,
    codeGen: ReplCodeGenerator,
): VariableName? {
    val generated = codeGen.process(col.asDataFrame(), property)
    val groupType = ColumnGroup::class.createStarProjectedType(false)
    return execute(generated, property, groupType)
}

/**
 * Handles a column variable. Only column groups are processed; any other column yields `null`.
 *
 * For a column group, the group is processed as a data frame and the resulting converter is
 * wrapped so that the expression it receives has `.asColumnGroup()` appended. The property is
 * cast to a star-projected, non-nullable [DataColumn] type.
 *
 * @return the value returned by the three-argument `execute` overload, or `null` when [col]
 * is not a column group.
 */
private fun KotlinKernelHost.updateAnyColVariable(
    col: AnyCol,
    property: KProperty<*>,
    codeGen: ReplCodeGenerator,
): VariableName? {
    if (!col.isColumnGroup()) return null

    val groupCode = codeGen.process(col.asColumnGroup().asDataFrame(), property)
    val wrappedCode = CodeWithConverter(groupCode.declarations) { expression ->
        groupCode.converter("$expression.asColumnGroup()")
    }
    val columnType = DataColumn::class.createStarProjectedType(false)
    return execute(wrappedCode, property, columnType)
}

override fun Builder.onLoaded() {
if (version != null) {
dependencies(
Expand Down Expand Up @@ -152,65 +272,17 @@ internal class Integration(
import("org.jetbrains.kotlinx.dataframe.dataTypes.*")
import("org.jetbrains.kotlinx.dataframe.impl.codeGen.urlCodeGenReader")

// Builds the code for `argument`, runs it, and returns the result name only when the
// generated code is non-blank and a converter is present; otherwise returns null.
fun KotlinKernelHost.execute(codeWithConverter: CodeWithConverter, argument: String): VariableName? {
    val code = codeWithConverter.with(argument)
    if (code.isBlank()) return null

    val result = execute(code)
    return if (codeWithConverter.hasConverter) result.name else null
}

// Runs the generated code against the property's name, with `!!` appended when the
// property's return type is marked nullable.
fun KotlinKernelHost.execute(codeWithConverter: CodeWithConverter, property: KProperty<*>): VariableName? {
    val nonNullAssertion = if (property.returnType.isMarkedNullable) "!!" else ""
    return execute(codeWithConverter, "${property.name}$nonNullAssertion")
}

updateVariable<ImportDataSchema> { importDataSchema, property ->
val formats = supportedFormats.filterIsInstance<SupportedCodeGenerationFormat>()
val name = property.name + "DataSchema"
when (val codeGenResult = CodeGenerator.urlCodeGenReader(importDataSchema.url, name, formats, true)) {
is CodeGenerationReadResult.Success -> {
val readDfMethod = codeGenResult.getReadDfMethod(importDataSchema.url.toExternalForm())
val code = readDfMethod.additionalImports.joinToString("\n") +
"\n" +
codeGenResult.code

execute(code)
execute("""DISPLAY("Data schema successfully imported as ${property.name}: $name")""")

name
}

is CodeGenerationReadResult.Error -> {
execute("""DISPLAY("Failed to read data schema from ${importDataSchema.url}: ${codeGenResult.reason}")""")
null
}
// One handler registered for `Any`: it inspects the runtime type of the value and forwards
// to the matching update*Variable helper, so the declared type of the variable is not what
// selects the branch. Values matching no branch yield null.
// NOTE(review): branches are tried top to bottom and `AnyCol` is tested before
// `ColumnGroup<*>` — confirm that an instance satisfying both is meant to take the AnyCol path.
updateVariable<Any> { instance, property ->
when (instance) {
is AnyCol -> updateAnyColVariable(instance, property, codeGen)
is ColumnGroup<*> -> updateColumnGroupVariable(instance, property, codeGen)
is AnyRow -> updateAnyRowVariable(instance, property, codeGen)
is AnyFrame -> updateAnyFrameVariable(instance, property, codeGen)
is ImportDataSchema -> updateImportDataSchemaVariable(instance, property)
else -> null
}
}

// Per-type handlers: each one passes the value to codeGen.process together with the
// property and executes the generated code against that property.

// Data frames are processed as-is.
updateVariable<AnyFrame> { df, property ->
execute(codeGen.process(df, property), property)
}

// Data rows are processed as-is.
updateVariable<AnyRow> { row, property ->
execute(codeGen.process(row, property), property)
}

// Column groups are converted to a data frame before processing.
updateVariable<ColumnGroup<*>> { col, property ->
execute(codeGen.process(col.asDataFrame(), property), property)
}

// Generic columns: only column groups are handled; every other column yields null.
updateVariable<AnyCol> { col, property ->
if (col.isColumnGroup()) {
// Wrap the converter so the expression it receives has `.asColumnGroup()` appended.
val codeWithConverter = codeGen.process(col.asColumnGroup().asDataFrame(), property).let { c ->
CodeWithConverter(c.declarations) { c.converter("$it.asColumnGroup()") }
}
execute(codeWithConverter, property)
} else null
}

fun KotlinKernelHost.addDataSchemas(classes: List<KClass<*>>) {
val code = classes.joinToString("\n") {
codeGen.process(it)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.jetbrains.kotlinx.dataframe.jupyter

import org.intellij.lang.annotations.Language
import org.jetbrains.kotlinx.jupyter.api.Code
import org.junit.Test

Expand All @@ -11,9 +12,20 @@ class CodeGenerationTests : DataFrameJupyterTest() {
}
}

// Declares a data frame behind an `Any?` return type, stores it in `df`, and then accesses
// column `a` on it; the whole snippet is handed to checkCompilation().
// The `val a` assignment is otherwise unused — presumably it exists so that @Language("kts")
// has a declaration to attach to; confirm before removing.
@Test
fun `Type erased dataframe`() {
@Language("kts")
val a = """
fun create(): Any? = dataFrameOf("a")(1)
val df = create()
df.a
""".checkCompilation()
}

@Test
fun `nullable dataframe`() {
"""
@Language("kts")
val a = """
fun create(): AnyFrame? = dataFrameOf("a")(1)
val df = create()
df.a
Expand All @@ -22,7 +34,8 @@ class CodeGenerationTests : DataFrameJupyterTest() {

@Test
fun `nullable columnGroup`() {
"""
@Language("kts")
val a = """
fun create(): AnyCol? = dataFrameOf("a")(1).asColumnGroup().asDataColumn()
val col = create()
col.a
Expand All @@ -31,7 +44,8 @@ class CodeGenerationTests : DataFrameJupyterTest() {

@Test
fun `nullable dataRow`() {
"""
@Language("kts")
val a = """
fun create(): AnyRow? = dataFrameOf("a")(1).single()
val row = create()
row.a
Expand Down
Loading

0 comments on commit ba4d321

Please sign in to comment.