Skip to content

Commit

Permalink
Merge pull request #204 from Kotlin/sort-grouped-df
Browse files Browse the repository at this point in the history
Sort grouped df
  • Loading branch information
Jolanrensen authored Dec 9, 2022
2 parents 6f3773c + 8a46ded commit b59cfb6
Show file tree
Hide file tree
Showing 5 changed files with 258 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ public interface SortDsl<out T> : ColumnsSelectionDsl<T> {
/**
 * Applies `nullsLast` to the column referenced by this property.
 *
 * The property is first turned into a column accessor via [toColumnAccessor]; the call is then
 * forwarded, together with [flag], to the `nullsLast` overload available on that accessor.
 *
 * @param flag passed through unchanged; defaults to `true`
 *   (presumably `true` places nulls after non-null values — confirm against the accessor overload).
 * @return the [ColumnSet] returned by the accessor overload.
 */
public fun <C> KProperty<C?>.nullsLast(flag: Boolean = true): ColumnSet<C?> {
    val accessor = toColumnAccessor()
    return accessor.nullsLast(flag)
}
}

/**
* [SortColumnsSelector] is a lambda that selects one or more columns to sort by.
* It runs with [SortDsl]`<T>` available both as the receiver (`this`) and as the single argument (`it`),
* and returns the selected columns as a [ColumnSet]`<C>`.
*
* In other words, it is shorthand for:
* ```kotlin
* SortDsl<T>.(it: SortDsl<T>) -> ColumnSet<C>
* ```
*/
public typealias SortColumnsSelector<T, C> = Selector<SortDsl<T>, ColumnSet<C>>

// region DataColumn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,17 @@ internal open class DataFrameReceiver<T>(
private val unresolvedColumnsPolicy: UnresolvedColumnsPolicy
) : DataFrameReceiverBase<T>(source.unbox()), SingleColumn<DataRow<T>> {

private fun <R> DataColumn<R>?.check(path: ColumnPath): DataColumn<R>? =
private fun <R> DataColumn<R>?.check(path: ColumnPath): DataColumn<R> =
when (this) {
null -> when (unresolvedColumnsPolicy) {
UnresolvedColumnsPolicy.Create, UnresolvedColumnsPolicy.Skip -> MissingColumnGroup<Any>(path, this@DataFrameReceiver).asDataColumn().cast()
UnresolvedColumnsPolicy.Create, UnresolvedColumnsPolicy.Skip -> MissingColumnGroup<Any>(
path,
this@DataFrameReceiver
).asDataColumn().cast()

UnresolvedColumnsPolicy.Fail -> error("Column $path not found")
}

is MissingDataColumn -> this
is ColumnGroup<*> -> ColumnGroupWithParent(null, this).asDataColumn().cast()
else -> this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,29 @@ import org.jetbrains.kotlinx.dataframe.columns.UnresolvedColumnsPolicy
import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
import org.jetbrains.kotlinx.dataframe.impl.columns.addPath
import org.jetbrains.kotlinx.dataframe.impl.columns.assertIsComparable
import org.jetbrains.kotlinx.dataframe.impl.columns.missing.MissingColumnGroup
import org.jetbrains.kotlinx.dataframe.impl.columns.resolve
import org.jetbrains.kotlinx.dataframe.impl.columns.toColumns
import org.jetbrains.kotlinx.dataframe.kind
import org.jetbrains.kotlinx.dataframe.nrow

internal fun <T, G> GroupBy<T, G>.sortByImpl(columns: SortColumnsSelector<G, *>): GroupBy<T, G> {
return toDataFrame()
@Suppress("UNCHECKED_CAST", "RemoveExplicitTypeArguments")
internal fun <T, G> GroupBy<T, G>.sortByImpl(columns: SortColumnsSelector<G, *>): GroupBy<T, G> =
toDataFrame()

// sort the individual groups by the columns specified
.update { groups }
.with { it.sortByImpl(UnresolvedColumnsPolicy.Skip, columns) }

// sort the groups by the columns specified (must be either the keys column or "groups")
// will do nothing if the columns specified are not the keys column or "groups"
.sortByImpl(UnresolvedColumnsPolicy.Skip, columns as SortColumnsSelector<T, *>)
.asGroupBy { it.getFrameColumn(groups.name()).castFrameColumn() }
}

.asGroupBy { it.getFrameColumn(groups.name()).castFrameColumn<G>() }

internal fun <T, C> DataFrame<T>.sortByImpl(
unresolvedColumnsPolicy: UnresolvedColumnsPolicy = UnresolvedColumnsPolicy.Fail,
columns: SortColumnsSelector<T, C>
columns: SortColumnsSelector<T, C>,
): DataFrame<T> {
val sortColumns = getSortColumns(columns, unresolvedColumnsPolicy)
if (sortColumns.isEmpty()) return this
Expand Down Expand Up @@ -61,17 +68,17 @@ internal fun AnyCol.createComparator(nullsLast: Boolean): java.util.Comparator<I

internal fun <T, C> DataFrame<T>.getSortColumns(
columns: SortColumnsSelector<T, C>,
unresolvedColumnsPolicy: UnresolvedColumnsPolicy
): List<SortColumnDescriptor<*>> {
return columns.toColumns().resolve(this, unresolvedColumnsPolicy)
unresolvedColumnsPolicy: UnresolvedColumnsPolicy,
): List<SortColumnDescriptor<*>> =
columns.toColumns().resolve(this, unresolvedColumnsPolicy)
.filterNot { it.data is MissingColumnGroup<*> } // can appear using [DataColumn<R>?.check] with UnresolvedColumnsPolicy.Skip
.map {
when (val col = it.data) {
is SortColumnDescriptor<*> -> col
is ValueColumn<*> -> SortColumnDescriptor(col)
else -> throw IllegalStateException("Can not use ${col.kind} as sort column")
}
}
}

/**
 * Modifier that can be attached to a sort column (see `addFlag`).
 */
internal enum class SortFlag {
    /** Switches the sort direction of the column to [SortDirection.Desc]. */
    Reversed,

    /** Turns on the `nullsLast` option of the resulting [SortColumnDescriptor]. */
    NullsLast
}

Expand All @@ -86,12 +93,14 @@ internal fun <C> ColumnWithPath<C>.addFlag(flag: SortFlag): ColumnWithPath<C> {
SortFlag.NullsLast -> SortColumnDescriptor(col.column, col.direction, true)
}
}

is ValueColumn -> {
when (flag) {
SortFlag.Reversed -> SortColumnDescriptor(col, SortDirection.Desc)
SortFlag.NullsLast -> SortColumnDescriptor(col, SortDirection.Asc, true)
}
}

else -> throw IllegalArgumentException("Can not apply sort flag to column kind ${col.kind}")
}.addPath(path)
}
Expand All @@ -103,7 +112,7 @@ internal class ColumnsWithSortFlag<C>(val column: ColumnSet<C>, val flag: SortFl
internal class SortColumnDescriptor<C>(
val column: ValueColumn<C>,
val direction: SortDirection = SortDirection.Asc,
val nullsLast: Boolean = false
val nullsLast: Boolean = false,
) : ValueColumn<C> by column

/**
 * Direction in which a sort column is ordered.
 */
internal enum class SortDirection {
    /** Ascending order; the default direction of a [SortColumnDescriptor]. */
    Asc,

    /** Descending order; used when the `Reversed` sort flag is applied. */
    Desc
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package org.jetbrains.kotlinx.dataframe.api

import io.kotest.matchers.shouldBe
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.alsoDebug
import org.jetbrains.kotlinx.dataframe.io.read
import org.junit.Test

/**
 * Tests for `sortBy` applied to a grouped dataframe (the result of `groupBy`).
 */
class SortGroupedDataframeTests {

    /**
     * Sorting the grouped iris data by column name must render the same as
     * sorting it with the equivalent column-selector lambda.
     */
    @Test
    fun `Sorted grouped iris dataset`() {
        val irisData = DataFrame.read("src/test/resources/irisDataset.csv")
        irisData.alsoDebug()

        val grouped = irisData.groupBy("variety")
        val sortedByName = grouped.sortBy("petal.length")
        // The selector parameter is named explicitly instead of relying on a nested implicit `it`.
        val sortedBySelector = grouped.sortBy { cols -> cols["petal.length"] }

        sortedByName.toString() shouldBe sortedBySelector.toString()
    }

    enum class State {
        Idle, Productive, Maintenance
    }

    /**
     * Checks that a grouped dataframe that was sorted with `sortBy` can still be extended with `add`
     * and transformed with `updateGroups`, and that the final result is not empty.
     */
    @Test
    fun test4() {
        class Event(val toolId: String, val state: State, val timestamp: Long)

        val tool1 = "tool_1"
        val tool2 = "tool_2"
        val tool3 = "tool_3"

        val events = listOf(
            Event(tool1, State.Idle, 0),
            Event(tool1, State.Productive, 5),
            Event(tool2, State.Idle, 0),
            Event(tool2, State.Maintenance, 10),
            Event(tool2, State.Idle, 20),
            Event(tool3, State.Idle, 0),
            Event(tool3, State.Productive, 25),
        ).toDataFrame()

        val lastTimestamp = events.maxOf { getValue<Long>("timestamp") }

        // Named `sortedGroups` (not `groupBy`) so the local does not shadow the `groupBy` API function.
        val sortedGroups = events
            .groupBy("toolId")
            .sortBy("timestamp")
            .add("stateDuration") {
                // Duration until the next row's timestamp; the last row falls back to `lastTimestamp`.
                (next()?.getValue("timestamp") ?: lastTimestamp) - getValue<Long>("timestamp")
            }

        sortedGroups.toDataFrame().alsoDebug()
        sortedGroups.schema().print()
        sortedGroups.keys.print()
        sortedGroups.keys[0].print()

        // The group is named explicitly so it cannot be confused with the `it` of the nested `from` lambda.
        val df1 = sortedGroups.updateGroups { group ->
            // One row per possible state, so that every state is present after the join.
            val missingValues = State.values().asList().toDataFrame {
                "state" from { it }
            }

            val joined = group
                .fullJoin(missingValues, "state")
                .fillNulls("stateDuration")
                .with { 100L }

            joined.groupBy("state").sumFor("stateDuration")
        }

        df1.toDataFrame().alsoDebug().isNotEmpty() shouldBe true
    }
}
151 changes: 151 additions & 0 deletions core/src/test/resources/irisDataset.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"sepal.length","sepal.width","petal.length","petal.width","variety"
5.1,3.5,1.4,.2,"Setosa"
4.9,3,1.4,.2,"Setosa"
4.7,3.2,1.3,.2,"Setosa"
4.6,3.1,1.5,.2,"Setosa"
5,3.6,1.4,.2,"Setosa"
5.4,3.9,1.7,.4,"Setosa"
4.6,3.4,1.4,.3,"Setosa"
5,3.4,1.5,.2,"Setosa"
4.4,2.9,1.4,.2,"Setosa"
4.9,3.1,1.5,.1,"Setosa"
5.4,3.7,1.5,.2,"Setosa"
4.8,3.4,1.6,.2,"Setosa"
4.8,3,1.4,.1,"Setosa"
4.3,3,1.1,.1,"Setosa"
5.8,4,1.2,.2,"Setosa"
5.7,4.4,1.5,.4,"Setosa"
5.4,3.9,1.3,.4,"Setosa"
5.1,3.5,1.4,.3,"Setosa"
5.7,3.8,1.7,.3,"Setosa"
5.1,3.8,1.5,.3,"Setosa"
5.4,3.4,1.7,.2,"Setosa"
5.1,3.7,1.5,.4,"Setosa"
4.6,3.6,1,.2,"Setosa"
5.1,3.3,1.7,.5,"Setosa"
4.8,3.4,1.9,.2,"Setosa"
5,3,1.6,.2,"Setosa"
5,3.4,1.6,.4,"Setosa"
5.2,3.5,1.5,.2,"Setosa"
5.2,3.4,1.4,.2,"Setosa"
4.7,3.2,1.6,.2,"Setosa"
4.8,3.1,1.6,.2,"Setosa"
5.4,3.4,1.5,.4,"Setosa"
5.2,4.1,1.5,.1,"Setosa"
5.5,4.2,1.4,.2,"Setosa"
4.9,3.1,1.5,.2,"Setosa"
5,3.2,1.2,.2,"Setosa"
5.5,3.5,1.3,.2,"Setosa"
4.9,3.6,1.4,.1,"Setosa"
4.4,3,1.3,.2,"Setosa"
5.1,3.4,1.5,.2,"Setosa"
5,3.5,1.3,.3,"Setosa"
4.5,2.3,1.3,.3,"Setosa"
4.4,3.2,1.3,.2,"Setosa"
5,3.5,1.6,.6,"Setosa"
5.1,3.8,1.9,.4,"Setosa"
4.8,3,1.4,.3,"Setosa"
5.1,3.8,1.6,.2,"Setosa"
4.6,3.2,1.4,.2,"Setosa"
5.3,3.7,1.5,.2,"Setosa"
5,3.3,1.4,.2,"Setosa"
7,3.2,4.7,1.4,"Versicolor"
6.4,3.2,4.5,1.5,"Versicolor"
6.9,3.1,4.9,1.5,"Versicolor"
5.5,2.3,4,1.3,"Versicolor"
6.5,2.8,4.6,1.5,"Versicolor"
5.7,2.8,4.5,1.3,"Versicolor"
6.3,3.3,4.7,1.6,"Versicolor"
4.9,2.4,3.3,1,"Versicolor"
6.6,2.9,4.6,1.3,"Versicolor"
5.2,2.7,3.9,1.4,"Versicolor"
5,2,3.5,1,"Versicolor"
5.9,3,4.2,1.5,"Versicolor"
6,2.2,4,1,"Versicolor"
6.1,2.9,4.7,1.4,"Versicolor"
5.6,2.9,3.6,1.3,"Versicolor"
6.7,3.1,4.4,1.4,"Versicolor"
5.6,3,4.5,1.5,"Versicolor"
5.8,2.7,4.1,1,"Versicolor"
6.2,2.2,4.5,1.5,"Versicolor"
5.6,2.5,3.9,1.1,"Versicolor"
5.9,3.2,4.8,1.8,"Versicolor"
6.1,2.8,4,1.3,"Versicolor"
6.3,2.5,4.9,1.5,"Versicolor"
6.1,2.8,4.7,1.2,"Versicolor"
6.4,2.9,4.3,1.3,"Versicolor"
6.6,3,4.4,1.4,"Versicolor"
6.8,2.8,4.8,1.4,"Versicolor"
6.7,3,5,1.7,"Versicolor"
6,2.9,4.5,1.5,"Versicolor"
5.7,2.6,3.5,1,"Versicolor"
5.5,2.4,3.8,1.1,"Versicolor"
5.5,2.4,3.7,1,"Versicolor"
5.8,2.7,3.9,1.2,"Versicolor"
6,2.7,5.1,1.6,"Versicolor"
5.4,3,4.5,1.5,"Versicolor"
6,3.4,4.5,1.6,"Versicolor"
6.7,3.1,4.7,1.5,"Versicolor"
6.3,2.3,4.4,1.3,"Versicolor"
5.6,3,4.1,1.3,"Versicolor"
5.5,2.5,4,1.3,"Versicolor"
5.5,2.6,4.4,1.2,"Versicolor"
6.1,3,4.6,1.4,"Versicolor"
5.8,2.6,4,1.2,"Versicolor"
5,2.3,3.3,1,"Versicolor"
5.6,2.7,4.2,1.3,"Versicolor"
5.7,3,4.2,1.2,"Versicolor"
5.7,2.9,4.2,1.3,"Versicolor"
6.2,2.9,4.3,1.3,"Versicolor"
5.1,2.5,3,1.1,"Versicolor"
5.7,2.8,4.1,1.3,"Versicolor"
6.3,3.3,6,2.5,"Virginica"
5.8,2.7,5.1,1.9,"Virginica"
7.1,3,5.9,2.1,"Virginica"
6.3,2.9,5.6,1.8,"Virginica"
6.5,3,5.8,2.2,"Virginica"
7.6,3,6.6,2.1,"Virginica"
4.9,2.5,4.5,1.7,"Virginica"
7.3,2.9,6.3,1.8,"Virginica"
6.7,2.5,5.8,1.8,"Virginica"
7.2,3.6,6.1,2.5,"Virginica"
6.5,3.2,5.1,2,"Virginica"
6.4,2.7,5.3,1.9,"Virginica"
6.8,3,5.5,2.1,"Virginica"
5.7,2.5,5,2,"Virginica"
5.8,2.8,5.1,2.4,"Virginica"
6.4,3.2,5.3,2.3,"Virginica"
6.5,3,5.5,1.8,"Virginica"
7.7,3.8,6.7,2.2,"Virginica"
7.7,2.6,6.9,2.3,"Virginica"
6,2.2,5,1.5,"Virginica"
6.9,3.2,5.7,2.3,"Virginica"
5.6,2.8,4.9,2,"Virginica"
7.7,2.8,6.7,2,"Virginica"
6.3,2.7,4.9,1.8,"Virginica"
6.7,3.3,5.7,2.1,"Virginica"
7.2,3.2,6,1.8,"Virginica"
6.2,2.8,4.8,1.8,"Virginica"
6.1,3,4.9,1.8,"Virginica"
6.4,2.8,5.6,2.1,"Virginica"
7.2,3,5.8,1.6,"Virginica"
7.4,2.8,6.1,1.9,"Virginica"
7.9,3.8,6.4,2,"Virginica"
6.4,2.8,5.6,2.2,"Virginica"
6.3,2.8,5.1,1.5,"Virginica"
6.1,2.6,5.6,1.4,"Virginica"
7.7,3,6.1,2.3,"Virginica"
6.3,3.4,5.6,2.4,"Virginica"
6.4,3.1,5.5,1.8,"Virginica"
6,3,4.8,1.8,"Virginica"
6.9,3.1,5.4,2.1,"Virginica"
6.7,3.1,5.6,2.4,"Virginica"
6.9,3.1,5.1,2.3,"Virginica"
5.8,2.7,5.1,1.9,"Virginica"
6.8,3.2,5.9,2.3,"Virginica"
6.7,3.3,5.7,2.5,"Virginica"
6.7,3,5.2,2.3,"Virginica"
6.3,2.5,5,1.9,"Virginica"
6.5,3,5.2,2,"Virginica"
6.2,3.4,5.4,2.3,"Virginica"
5.9,3,5.1,1.8,"Virginica"

0 comments on commit b59cfb6

Please sign in to comment.