-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implement DataFrame API and its tests
- Loading branch information
Showing
4 changed files
with
113 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
package io.hqew.kquery.logical | ||
|
||
import io.hqew.kquery.datatypes.Schema | ||
|
||
/**
 * Fluent API for describing a query as a chain of relational operations.
 *
 * Every transformation yields a [DataFrame], so calls can be chained. The
 * accumulated query is exposed through [logicalPlan] and its output shape
 * through [schema].
 */
interface DataFrame {

    /**
     * Applies a projection.
     *
     * @param expr the expressions that make up the projection.
     * @return the [DataFrame] representing the projected result.
     */
    fun project(expr: List<LogicalExpr>): DataFrame

    /**
     * Applies a filter.
     *
     * @param expr the filter expression.
     * @return the [DataFrame] representing the filtered result.
     */
    fun filter(expr: LogicalExpr): DataFrame

    /**
     * Applies an aggregation.
     *
     * @param groupBy the grouping expressions.
     * @param aggregateExpr the aggregate expressions to evaluate.
     * @return the [DataFrame] representing the aggregated result.
     */
    fun aggregate(groupBy: List<LogicalExpr>, aggregateExpr: List<AggregateExpr>): DataFrame

    /** Returns the logical plan backing this [DataFrame]. */
    fun logicalPlan(): LogicalPlan

    /** Returns the [Schema] of this [DataFrame]. */
    fun schema(): Schema
}
|
||
/**
 * [DataFrame] implementation backed by a single [LogicalPlan].
 *
 * Each transformation wraps the current [plan] in a new plan node and returns
 * a fresh [DataFrameImpl] around it; the receiver itself is never modified.
 */
class DataFrameImpl(private val plan: LogicalPlan) : DataFrame {

    /** Wraps the current plan in a [Projection] over [expr]. */
    override fun project(expr: List<LogicalExpr>): DataFrame =
        DataFrameImpl(Projection(plan, expr))

    /** Wraps the current plan in a [Selection] using [expr] as the filter expression. */
    override fun filter(expr: LogicalExpr): DataFrame =
        DataFrameImpl(Selection(plan, expr))

    /** Wraps the current plan in an [Aggregate] built from [groupBy] and [aggregateExpr]. */
    override fun aggregate(groupBy: List<LogicalExpr>, aggregateExpr: List<AggregateExpr>): DataFrame =
        DataFrameImpl(Aggregate(plan, groupBy, aggregateExpr))

    /** Delegates to the schema reported by the underlying plan. */
    override fun schema(): Schema = plan.schema()

    /** Exposes the underlying plan as-is. */
    override fun logicalPlan(): LogicalPlan = plan
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
package io.hqew.kquery.logical | ||
|
||
import io.hqew.kquery.datasource.CsvDataSource | ||
import org.junit.Test | ||
import org.junit.jupiter.api.TestInstance | ||
import java.io.File | ||
import kotlin.test.assertEquals | ||
|
||
/**
 * Tests that the [DataFrame] API builds the expected logical plans on top of
 * a CSV-backed [Scan].
 *
 * NOTE(review): this file imports JUnit 4's `org.junit.Test` alongside JUnit 5's
 * `@TestInstance`; confirm which test engine runs this class and align the
 * imports accordingly.
 */
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class DataFrameTest {

    /** A single-column projection yields a one-field schema and a `Projection` plan. */
    @Test
    fun `test DataFrame build`() {
        val df = csv().project(listOf(Column("id")))

        // kotlin.test.assertEquals takes (expected, actual); keeping that order
        // makes failure messages label the two values correctly.
        assertEquals(1, df.schema().fields.size)
        assertEquals("Projection: #id", df.logicalPlan().toString())
    }

    /** A filter followed by a projection nests Selection under Projection in the plan. */
    @Test
    fun `test DataFrame filter`() {
        val df = csv()
            .filter(Column("first_name").eq(LiteralString("John")))
            .project(listOf(Column("id")))

        val expected =
            "Projection: #id\n" +
                "\tSelection: #first_name = 'John'\n" +
                "\t\tScan: employee.csv; projection=None\n"

        assertEquals(1, df.schema().fields.size)
        assertEquals(expected, format(df.logicalPlan()))
    }

    /** The infix expression helpers (`eq`, `mult`, `alias`, `gt`) render as expected in the plan. */
    @Test
    fun `multiplier and alias`() {
        val df =
            csv()
                .filter(col("state") eq lit("CO"))
                .project(
                    listOf(
                        col("id"),
                        col("first_name"),
                        col("last_name"),
                        col("salary"),
                        (col("salary") mult lit(0.1)) alias "bonus",
                    ),
                )
                .filter(col("bonus") gt lit(1000))

        val expected =
            "Selection: #bonus > 1000\n" +
                "\tProjection: #id, #first_name, #last_name, #salary, #salary * 0.1 as bonus\n" +
                "\t\tSelection: #state = 'CO'\n" +
                "\t\t\tScan: employee.csv; projection=None\n"

        val actual = format(df.logicalPlan())

        assertEquals(expected, actual)
    }

    /** Builds a [DataFrame] that scans `../testdata/employee.csv` with no column projection. */
    private fun csv(): DataFrame {
        val csvSource = CsvDataSource(
            File("../testdata", "employee.csv").absolutePath,
            schema = null,
            hasHeaders = true,
            batchSize = 1024,
        )

        return DataFrameImpl(Scan("employee.csv", csvSource, listOf()))
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters