Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Update.asFrame now takes filter into account. #283

Merged
merged 4 commits into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import org.jetbrains.kotlinx.dataframe.RowValueExpression
import org.jetbrains.kotlinx.dataframe.RowValueFilter
import org.jetbrains.kotlinx.dataframe.Selector
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
import org.jetbrains.kotlinx.dataframe.impl.api.asFrameImpl
import org.jetbrains.kotlinx.dataframe.impl.api.updateImpl
import org.jetbrains.kotlinx.dataframe.impl.api.updateWithValuePerColumnImpl
import org.jetbrains.kotlinx.dataframe.impl.columns.toColumnSet
Expand Down Expand Up @@ -57,7 +58,7 @@ public infix fun <T, C> Update<T, C>.with(expression: UpdateExpression<T, C, C?>
}

public infix fun <T, C, R> Update<T, DataRow<C>>.asFrame(expression: DataFrameExpression<C, DataFrame<R>>): DataFrame<T> =
df.replace(columns).with { it.asColumnGroup().let { expression(it, it) }.asColumnGroup(it.name()) }
asFrameImpl(expression)

public fun <T, C> Update<T, C>.asNullable(): Update<T, C?> = this as Update<T, C?>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,28 @@ import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.AnyRow
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataFrameExpression
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.RowValueFilter
import org.jetbrains.kotlinx.dataframe.Selector
import org.jetbrains.kotlinx.dataframe.api.AddDataRow
import org.jetbrains.kotlinx.dataframe.api.Update
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
import org.jetbrains.kotlinx.dataframe.api.asDataFrame
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.indices
import org.jetbrains.kotlinx.dataframe.api.isEmpty
import org.jetbrains.kotlinx.dataframe.api.name
import org.jetbrains.kotlinx.dataframe.api.replace
import org.jetbrains.kotlinx.dataframe.api.toColumn
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.api.with
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
import org.jetbrains.kotlinx.dataframe.columns.size
import org.jetbrains.kotlinx.dataframe.impl.columns.AddDataRowImpl
import org.jetbrains.kotlinx.dataframe.impl.createDataCollector
import org.jetbrains.kotlinx.dataframe.index
import org.jetbrains.kotlinx.dataframe.type
import kotlin.reflect.full.isSubclassOf
import kotlin.reflect.full.withNullability
Expand All @@ -40,10 +46,48 @@ internal fun <T, C> Update<T, C>.updateWithValuePerColumnImpl(selector: Selector
}
}

/**
* Implementation for Update As Frame:
* Replaces selected column groups with the result of the expression only where the filter is true.
*/
internal fun <T, C, R> Update<T, DataRow<C>>.asFrameImpl(expression: DataFrameExpression<C, DataFrame<R>>): DataFrame<T> =
if (df.isEmpty()) df
else df.replace(columns).with {
// First, we create an updated column group with the result of the expression
val srcColumnGroup = it.asColumnGroup()
val updatedColumnGroup = srcColumnGroup
.asDataFrame()
.let { expression(it, it) }
.asColumnGroup(srcColumnGroup.name())

if (filter == null) {
zaleslaw marked this conversation as resolved.
Show resolved Hide resolved
// If there is no filter, we simply return the updated column group
updatedColumnGroup
} else {
// If there is a filter, then we replace the rows of the source column group with the updated column group
// only if they satisfy the filter
srcColumnGroup.replaceRowsIf(from = updatedColumnGroup) {
val srcRow = df[it.index]
val srcValue = srcRow[srcColumnGroup]

filter.invoke(srcRow, srcValue)
}
}
}

private fun <C, R> ColumnGroup<C>.replaceRowsIf(
from: ColumnGroup<R>,
condition: (DataRow<C>) -> Boolean = { true },
): ColumnGroup<C> = values()
.map { if (condition(it)) from[it.index] else it }
.toColumn(name)
.asColumnGroup()
.cast()

internal fun <T, C> DataColumn<C>.updateImpl(
df: DataFrame<T>,
filter: RowValueFilter<T, C>?,
expression: (AddDataRow<T>, DataColumn<C>, C) -> C?
expression: (AddDataRow<T>, DataColumn<C>, C) -> C?,
zaleslaw marked this conversation as resolved.
Show resolved Hide resolved
): DataColumn<C> {
val collector = createDataCollector<C>(size, type)
val src = this
Expand Down Expand Up @@ -75,6 +119,7 @@ internal fun <T> DataColumn<T>.updateWith(values: List<T>): DataColumn<T> = when
val groups = (values as List<AnyFrame>)
DataColumn.createFrameColumn(name, groups) as DataColumn<T>
}

is ColumnGroup<*> -> {
this.columns().mapIndexed { colIndex, col ->
val newValues = values.map {
Expand All @@ -88,6 +133,7 @@ internal fun <T> DataColumn<T>.updateWith(values: List<T>): DataColumn<T> = when
col.updateWith(newValues)
}.toDataFrame().let { DataColumn.createColumnGroup(name, it) } as DataColumn<T>
}

else -> {
var nulls = false
val kclass = type.jvmErasure
Expand Down
42 changes: 42 additions & 0 deletions core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/update.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.jetbrains.kotlinx.dataframe.api
import io.kotest.matchers.shouldBe
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
import org.jetbrains.kotlinx.dataframe.size
import org.junit.Test

class UpdateTests {
Expand All @@ -14,6 +15,47 @@ class UpdateTests {
df.update { col }.with { 2 } shouldBe df
}

@DataSchema
interface DataPart {
val a: Int
val b: String
}

@DataSchema
data class Data(
override val a: Int,
override val b: String,
val c: Boolean,
) : DataPart

@Test
fun `update asFrame`() {
val df = listOf(
Data(1, "a", true),
Data(2, "b", false),
).toDataFrame()

val group by columnGroup<DataPart>() named "Some Group"
val groupedDf = df.group { a and b }.into { group }

val res = groupedDf
.update { group }
.where { !c }
.asFrame {
// size should still be full df size
size.nrow shouldBe 2

// this will only apply to rows where `.where { !c }` holds
update { a }.with { 0 }
}

val (first, second) = res[{ group }].map { it.a }.toList()
first shouldBe 1
second shouldBe 0

res[{ group }].name() shouldBe "Some Group"
}

@DataSchema
interface SchemaA {
val i: Int?
Expand Down