Skip to content

Commit

Permalink
Merge pull request #283 from Kotlin/asFrame-fix
Browse files Browse the repository at this point in the history
[Fix] Update.asFrame now takes filter into account.
  • Loading branch information
Jolanrensen authored Mar 2, 2023
2 parents c353922 + 1742f43 commit a089dcb
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import org.jetbrains.kotlinx.dataframe.RowValueExpression
import org.jetbrains.kotlinx.dataframe.RowValueFilter
import org.jetbrains.kotlinx.dataframe.Selector
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
import org.jetbrains.kotlinx.dataframe.impl.api.asFrameImpl
import org.jetbrains.kotlinx.dataframe.impl.api.updateImpl
import org.jetbrains.kotlinx.dataframe.impl.api.updateWithValuePerColumnImpl
import org.jetbrains.kotlinx.dataframe.impl.columns.toColumnSet
Expand Down Expand Up @@ -57,7 +58,7 @@ public infix fun <T, C> Update<T, C>.with(expression: UpdateExpression<T, C, C?>
}

public infix fun <T, C, R> Update<T, DataRow<C>>.asFrame(expression: DataFrameExpression<C, DataFrame<R>>): DataFrame<T> =
df.replace(columns).with { it.asColumnGroup().let { expression(it, it) }.asColumnGroup(it.name()) }
asFrameImpl(expression)

public fun <T, C> Update<T, C>.asNullable(): Update<T, C?> = this as Update<T, C?>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,28 @@ import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.AnyRow
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataFrameExpression
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.RowValueFilter
import org.jetbrains.kotlinx.dataframe.Selector
import org.jetbrains.kotlinx.dataframe.api.AddDataRow
import org.jetbrains.kotlinx.dataframe.api.Update
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
import org.jetbrains.kotlinx.dataframe.api.asDataFrame
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.indices
import org.jetbrains.kotlinx.dataframe.api.isEmpty
import org.jetbrains.kotlinx.dataframe.api.name
import org.jetbrains.kotlinx.dataframe.api.replace
import org.jetbrains.kotlinx.dataframe.api.toColumn
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.api.with
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
import org.jetbrains.kotlinx.dataframe.columns.size
import org.jetbrains.kotlinx.dataframe.impl.columns.AddDataRowImpl
import org.jetbrains.kotlinx.dataframe.impl.createDataCollector
import org.jetbrains.kotlinx.dataframe.index
import org.jetbrains.kotlinx.dataframe.type
import kotlin.reflect.full.isSubclassOf
import kotlin.reflect.full.withNullability
Expand All @@ -40,10 +46,48 @@ internal fun <T, C> Update<T, C>.updateWithValuePerColumnImpl(selector: Selector
}
}

/**
* Implementation for Update As Frame:
* Replaces selected column groups with the result of the expression only where the filter is true.
*/
internal fun <T, C, R> Update<T, DataRow<C>>.asFrameImpl(expression: DataFrameExpression<C, DataFrame<R>>): DataFrame<T> =
if (df.isEmpty()) df
else df.replace(columns).with {
// First, we create an updated column group with the result of the expression
val srcColumnGroup = it.asColumnGroup()
val updatedColumnGroup = srcColumnGroup
.asDataFrame()
.let { expression(it, it) }
.asColumnGroup(srcColumnGroup.name())

if (filter == null) {
// If there is no filter, we simply return the updated column group
updatedColumnGroup
} else {
// If there is a filter, then we replace the rows of the source column group with the updated column group
// only if they satisfy the filter
srcColumnGroup.replaceRowsIf(from = updatedColumnGroup) {
val srcRow = df[it.index]
val srcValue = srcRow[srcColumnGroup]

filter.invoke(srcRow, srcValue)
}
}
}

private fun <C, R> ColumnGroup<C>.replaceRowsIf(
from: ColumnGroup<R>,
condition: (DataRow<C>) -> Boolean = { true },
): ColumnGroup<C> = values()
.map { if (condition(it)) from[it.index] else it }
.toColumn(name)
.asColumnGroup()
.cast()

internal fun <T, C> DataColumn<C>.updateImpl(
df: DataFrame<T>,
filter: RowValueFilter<T, C>?,
expression: (AddDataRow<T>, DataColumn<C>, C) -> C?
expression: (AddDataRow<T>, DataColumn<C>, C) -> C?,
): DataColumn<C> {
val collector = createDataCollector<C>(size, type)
val src = this
Expand Down Expand Up @@ -75,6 +119,7 @@ internal fun <T> DataColumn<T>.updateWith(values: List<T>): DataColumn<T> = when
val groups = (values as List<AnyFrame>)
DataColumn.createFrameColumn(name, groups) as DataColumn<T>
}

is ColumnGroup<*> -> {
this.columns().mapIndexed { colIndex, col ->
val newValues = values.map {
Expand All @@ -88,6 +133,7 @@ internal fun <T> DataColumn<T>.updateWith(values: List<T>): DataColumn<T> = when
col.updateWith(newValues)
}.toDataFrame().let { DataColumn.createColumnGroup(name, it) } as DataColumn<T>
}

else -> {
var nulls = false
val kclass = type.jvmErasure
Expand Down
42 changes: 42 additions & 0 deletions core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/update.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.jetbrains.kotlinx.dataframe.api
import io.kotest.matchers.shouldBe
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
import org.jetbrains.kotlinx.dataframe.size
import org.junit.Test

class UpdateTests {
Expand All @@ -14,6 +15,47 @@ class UpdateTests {
df.update { col }.with { 2 } shouldBe df
}

@DataSchema
interface DataPart {
val a: Int
val b: String
}

@DataSchema
data class Data(
override val a: Int,
override val b: String,
val c: Boolean,
) : DataPart

@Test
fun `update asFrame`() {
val df = listOf(
Data(1, "a", true),
Data(2, "b", false),
).toDataFrame()

val group by columnGroup<DataPart>() named "Some Group"
val groupedDf = df.group { a and b }.into { group }

val res = groupedDf
.update { group }
.where { !c }
.asFrame {
// size should still be full df size
size.nrow shouldBe 2

// this will only apply to rows where `.where { !c }` holds
update { a }.with { 0 }
}

val (first, second) = res[{ group }].map { it.a }.toList()
first shouldBe 1
second shouldBe 0

res[{ group }].name() shouldBe "Some Group"
}

@DataSchema
interface SchemaA {
val i: Int?
Expand Down

0 comments on commit a089dcb

Please sign in to comment.