Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More Converting operations #133

Merged
merged 7 commits into from
Aug 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import org.jetbrains.kotlinx.dataframe.RowValueExpression
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
import org.jetbrains.kotlinx.dataframe.dataTypes.IFRAME
import org.jetbrains.kotlinx.dataframe.dataTypes.IMG
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConversionException
import org.jetbrains.kotlinx.dataframe.impl.api.Parsers
import org.jetbrains.kotlinx.dataframe.impl.api.convertRowColumnImpl
import org.jetbrains.kotlinx.dataframe.impl.api.convertToTypeImpl
Expand All @@ -30,7 +31,7 @@ import org.jetbrains.kotlinx.dataframe.io.toDataFrame
import java.math.BigDecimal
import java.net.URL
import java.time.LocalTime
import java.util.*
import java.util.Locale
import kotlin.reflect.KProperty
import kotlin.reflect.KType
import kotlin.reflect.typeOf
Expand Down Expand Up @@ -99,7 +100,11 @@ public fun <T, C> Convert<T, C>.to(columnConverter: DataFrame<T>.(DataColumn<C>)
df.replace(columns).with { columnConverter(df, it) }

/**
 * Converts this column to the type [C].
 *
 * Delegates to [convertTo] with `typeOf<C>()` and casts the result to `DataColumn<C>`.
 */
public inline fun <reified C> AnyCol.convertTo(): DataColumn<C> = convertTo(typeOf<C>()) as DataColumn<C>
public fun AnyCol.convertTo(newType: KType): AnyCol = convertToTypeImpl(newType)
/**
 * Converts this column to [newType].
 *
 * A `String` column converted to `Double` (and a `String?` column converted to `Double?`)
 * is handed to `convertToDouble()`; every other combination is delegated to [convertToTypeImpl].
 */
public fun AnyCol.convertTo(newType: KType): AnyCol {
    val sourceType = type()
    return when {
        sourceType == typeOf<String>() && newType == typeOf<Double>() ->
            (this as DataColumn<String>).convertToDouble()
        sourceType == typeOf<String?>() && newType == typeOf<Double?>() ->
            (this as DataColumn<String?>).convertToDouble()
        else -> convertToTypeImpl(newType)
    }
}

/** Converts this non-nullable column to a [LocalDateTime] column by delegating to `convertTo()`. */
@JvmName("convertToLocalDateTimeFromT")
public fun <T : Any> DataColumn<T>.convertToLocalDateTime(): DataColumn<LocalDateTime> = convertTo()
Expand All @@ -125,6 +130,37 @@ public fun <T : Any> DataColumn<T?>.convertToString(): DataColumn<String?> = con
public fun <T : Any> DataColumn<T>.convertToDouble(): DataColumn<Double> = convertTo()
public fun <T : Any> DataColumn<T?>.convertToDouble(): DataColumn<Double?> = convertTo()

/**
 * Parses a String column to Double, taking the locale (number format) into account.
 *
 * If [locale] is specified, its number format is used for parsing.
 * If [locale] is null, the current system locale is used; if the column cannot be parsed
 * that way, the POSIX format is used instead.
 *
 * The work is done by the nullable overload: the column is viewed as nullable, converted,
 * and then cast back to a non-nullable column.
 */
@JvmName("convertToDoubleFromString")
public fun DataColumn<String>.convertToDouble(locale: Locale? = null): DataColumn<Double> =
    castToNullable().convertToDouble(locale).castToNotNullable()

/**
 * Parses a nullable String column to Double, taking the locale (number format) into account.
 *
 * If [locale] is specified, its number format is used for parsing and no fallback is attempted.
 * If [locale] is null, the current system locale is used; if the column cannot be parsed
 * that way, the whole column is parsed again using the POSIX format.
 *
 * Null values are kept as null; non-null values are trimmed before parsing.
 *
 * @throws TypeConversionException if a value cannot be parsed by the parser that is finally applied.
 */
@JvmName("convertToDoubleFromStringNullable")
public fun DataColumn<String?>.convertToDouble(locale: Locale? = null): DataColumn<Double?> {
    // Applies [parser] to every non-null value; the first value it rejects aborts the conversion.
    fun parseAllWith(parser: (String) -> Double?): DataColumn<Double?> = map { raw ->
        raw?.let { text ->
            parser(text.trim()) ?: throw TypeConversionException(text, typeOf<String>(), typeOf<Double>())
        }
    }

    // An explicitly requested locale is authoritative: failures propagate to the caller.
    if (locale != null) return parseAllWith(Parsers.getDoubleParser(locale))

    return try {
        parseAllWith(Parsers.getDoubleParser())
    } catch (e: TypeConversionException) {
        // The default-locale format did not fit the column; retry everything with the POSIX format.
        parseAllWith(Parsers.getDoubleParser(Locale.forLanguageTag("C.UTF-8")))
    }
}

/** Converts this non-nullable column to a [Float] column by delegating to `convertTo()`. */
@JvmName("convertToFloatFromT")
public fun <T : Any> DataColumn<T>.convertToFloat(): DataColumn<Float> = convertTo()
/** Converts this nullable column to a nullable [Float] column by delegating to `convertTo()`. */
public fun <T : Any> DataColumn<T?>.convertToFloat(): DataColumn<Float?> = convertTo()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.jetbrains.kotlinx.dataframe.type
import java.math.BigDecimal
import java.net.URL
import java.time.LocalTime
import java.util.Locale
import kotlin.math.roundToInt
import kotlin.math.roundToLong
import kotlin.reflect.KType
Expand Down Expand Up @@ -81,15 +82,15 @@ internal fun AnyCol.convertToTypeImpl(to: KType): AnyCol {
return when {
from == to -> this
from.isSubtypeOf(to) -> (this as DataColumnInternal<*>).changeType(to.withNullability(hasNulls()))
else -> when (val converter = getConverter(from, to)) {
else -> when (val converter = getConverter(from, to, ParserOptions(locale = Locale.getDefault()))) {
null -> when (from.classifier) {
Any::class, Number::class, java.io.Serializable::class -> {
// find converter for every value
val values = values.map {
it?.let {
val clazz = it.javaClass.kotlin
val type = clazz.createStarProjectedType(false)
val converter = getConverter(type, to) ?: throw TypeConverterNotFoundException(from, to)
val converter = getConverter(type, to, ParserOptions(locale = Locale.getDefault())) ?: throw TypeConverterNotFoundException(from, to)
converter(it)
}.checkNulls()
}
Expand All @@ -107,9 +108,9 @@ internal fun AnyCol.convertToTypeImpl(to: KType): AnyCol {
}
}

internal val convertersCache = mutableMapOf<Pair<KType, KType>, TypeConverter?>()
/**
 * Cache of converters keyed by (source type, target type, parser options).
 * The value type is nullable: `null` stands for "no converter" for that key.
 */
internal val convertersCache = mutableMapOf<Triple<KType, KType, ParserOptions?>, TypeConverter?>()

internal fun getConverter(from: KType, to: KType): TypeConverter? = convertersCache.getOrPut(from to to) { createConverter(from, to) }
/**
 * Returns the converter from [from] to [to] for the given parser [options],
 * creating it with [createConverter] on first request and caching the outcome.
 *
 * The presence of the key is checked explicitly because `getOrPut` treats a stored `null`
 * as a missing entry: with it, a "no converter" result would never be served from the cache
 * and [createConverter] would run again on every lookup of an unsupported conversion.
 *
 * @return the cached or newly created converter, or `null` if no converter exists.
 */
internal fun getConverter(from: KType, to: KType, options: ParserOptions? = null): TypeConverter? {
    val key = Triple(from, to, options)
    if (key in convertersCache) return convertersCache[key]
    return createConverter(from, to, options).also { convertersCache[key] = it }
}

internal typealias TypeConverter = (Any) -> Any?

Expand Down Expand Up @@ -205,6 +206,7 @@ internal fun createConverter(from: KType, to: KType, options: ParserOptions? = n
Byte::class -> convert<Number> { it.toByte() }
Short::class -> convert<Number> { it.toShort() }
Long::class -> convert<Number> { it.toLong() }
Boolean::class -> convert<Number> { it.toDouble() != 0.0 }
else -> null
}
Int::class -> when (toClass) {
Expand All @@ -214,6 +216,7 @@ internal fun createConverter(from: KType, to: KType, options: ParserOptions? = n
Short::class -> convert<Int> { it.toShort() }
Long::class -> convert<Int> { it.toLong() }
BigDecimal::class -> convert<Int> { it.toBigDecimal() }
Boolean::class -> convert<Int> { it != 0 }
LocalDateTime::class -> convert<Int> { it.toLong().toLocalDateTime(defaultTimeZone) }
LocalDate::class -> convert<Int> { it.toLong().toLocalDate(defaultTimeZone) }
// The source value in this branch is an Int, so receive it as Int and widen it explicitly,
// as the neighbouring LocalDateTime/LocalDate entries do. Receiving it as Long looks like
// a copy from the Long branch. NOTE(review): assumes convert<T> casts the incoming value to T — confirm.
java.time.LocalDateTime::class -> convert<Int> { it.toLong().toLocalDateTime(defaultTimeZone).toJavaLocalDateTime() }
Expand All @@ -227,6 +230,7 @@ internal fun createConverter(from: KType, to: KType, options: ParserOptions? = n
Long::class -> convert<Double> { it.roundToLong() }
Short::class -> convert<Double> { it.roundToInt().toShort() }
BigDecimal::class -> convert<Double> { it.toBigDecimal() }
Boolean::class -> convert<Double> { it != 0.0 }
else -> null
}
Long::class -> when (toClass) {
Expand All @@ -236,6 +240,7 @@ internal fun createConverter(from: KType, to: KType, options: ParserOptions? = n
Short::class -> convert<Long> { it.toShort() }
Int::class -> convert<Long> { it.toInt() }
BigDecimal::class -> convert<Long> { it.toBigDecimal() }
Boolean::class -> convert<Long> { it != 0L }
LocalDateTime::class -> convert<Long> { it.toLocalDateTime(defaultTimeZone) }
LocalDate::class -> convert<Long> { it.toLocalDate(defaultTimeZone) }
Instant::class -> convert<Long> { Instant.fromEpochMilliseconds(it) }
Expand Down Expand Up @@ -270,13 +275,15 @@ internal fun createConverter(from: KType, to: KType, options: ParserOptions? = n
Int::class -> convert<Float> { it.roundToInt() }
Short::class -> convert<Float> { it.roundToInt().toShort() }
BigDecimal::class -> convert<Float> { it.toBigDecimal() }
Boolean::class -> convert<Float> { it != 0.0F }
else -> null
}
BigDecimal::class -> when (toClass) {
Double::class -> convert<BigDecimal> { it.toDouble() }
Int::class -> convert<BigDecimal> { it.toInt() }
Float::class -> convert<BigDecimal> { it.toFloat() }
Long::class -> convert<BigDecimal> { it.toLong() }
// BigDecimal.equals also compares scale, so a zero such as 0.00 is not equal to BigDecimal.ZERO
// and would be converted to true. signum() looks at the numeric value only.
Boolean::class -> convert<BigDecimal> { it.signum() != 0 }
else -> null
}
LocalDateTime::class -> when (toClass) {
Expand All @@ -285,6 +292,7 @@ internal fun createConverter(from: KType, to: KType, options: ParserOptions? = n
Long::class -> convert<LocalDateTime> { it.toInstant(defaultTimeZone).toEpochMilliseconds() }
java.time.LocalDateTime::class -> convert<LocalDateTime> { it.toJavaLocalDateTime() }
java.time.LocalDate::class -> convert<LocalDateTime> { it.date.toJavaLocalDate() }
java.time.LocalTime::class -> convert<LocalDateTime> { it.toJavaLocalDateTime().toLocalTime() }
else -> null
}
java.time.LocalDateTime::class -> when (toClass) {
Expand All @@ -293,6 +301,7 @@ internal fun createConverter(from: KType, to: KType, options: ParserOptions? = n
Instant::class -> convert<java.time.LocalDateTime> { it.toKotlinLocalDateTime().toInstant(defaultTimeZone) }
Long::class -> convert<java.time.LocalDateTime> { it.toKotlinLocalDateTime().toInstant(defaultTimeZone).toEpochMilliseconds() }
java.time.LocalDate::class -> convert<java.time.LocalDateTime> { it.toLocalDate() }
java.time.LocalTime::class -> convert<java.time.LocalDateTime> { it.toLocalTime() }
else -> null
}
LocalDate::class -> when (toClass) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,12 @@ internal object Parsers : GlobalParserOptions {
inline fun <reified T : Any> stringParserWithOptions(noinline body: (ParserOptions?) -> ((String) -> T?)) =
StringParserWithFormat(typeOf<T>(), body)

/**
 * String-to-Double parser that depends on parser options.
 *
 * The number format is taken from the locale in the options; when the options or their locale
 * are absent, the number format of the default locale is used.
 */
private val parserToDoubleWithOptions = stringParserWithOptions { parserOptions ->
    val effectiveLocale = parserOptions?.locale ?: Locale.getDefault()
    val format = NumberFormat.getInstance(effectiveLocale)
    val parseWithFormat = { text: String -> text.parseDouble(format) }
    parseWithFormat
}

private val parsersOrder = listOf(
stringParser { it.toIntOrNull() },
stringParser { it.toLongOrNull() },
Expand Down Expand Up @@ -226,12 +232,12 @@ internal object Parsers : GlobalParserOptions {

stringParser { it.toUrlOrNull() },

stringParserWithOptions { options ->
// Double, with explicit number format or taken from current locale
parserToDoubleWithOptions,

// Double, with POSIX format
stringParser { it.parseDouble(NumberFormat.getInstance(Locale.forLanguageTag("C.UTF-8"))) },

val numberFormat = NumberFormat.getInstance(options?.locale ?: Locale.getDefault())
val parser = { it: String -> it.parseDouble(numberFormat) }
parser
},
stringParser { it.toBooleanOrNull() },
stringParser { it.toBigDecimalOrNull() },

Expand Down Expand Up @@ -266,6 +272,13 @@ internal object Parsers : GlobalParserOptions {
) else null
return parser.applyOptions(options)
}

/**
 * Returns a String-to-Double parser built from [parserToDoubleWithOptions].
 *
 * @param locale locale to wrap into [ParserOptions]; when null, no options are passed.
 * @return a function returning the parsed Double, or null.
 */
internal fun getDoubleParser(locale: Locale? = null): (String) -> Double? {
    val parserOptions = locale?.let { ParserOptions(locale = it) }
    return parserToDoubleWithOptions.applyOptions(parserOptions)
}
}

internal fun DataColumn<String?>.tryParseImpl(options: ParserOptions?): DataColumn<*> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,27 @@ class CsvTests {
assertColumnType("quality", Int::class)
}

@Test
fun `read standard CSV with floats when user has alternative locale`() {
    // The default locale is global JVM state: remember it and restore it afterwards.
    val savedLocale = Locale.getDefault()
    try {
        Locale.setDefault(Locale.forLanguageTag("ru-RU"))
        val schema = DataFrame.readCSV(wineCsv, delimiter = ';').schema()

        // Asserts that the column exists and that its type classifier is [kClass].
        fun assertColumnType(columnName: String, kClass: KClass<*>) {
            val column = schema.columns[columnName]
            column.shouldNotBeNull()
            column.type.classifier shouldBe kClass
        }

        assertColumnType("citric acid", Double::class)
        assertColumnType("alcohol", Double::class)
        assertColumnType("quality", Int::class)
    } finally {
        Locale.setDefault(savedLocale)
    }
}

@Test
fun `read with custom header`() {
val header = ('A'..'K').map { it.toString() }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
package org.jetbrains.kotlinx.dataframe.io

import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import kotlinx.datetime.LocalDateTime
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.columnOf
import org.jetbrains.kotlinx.dataframe.api.convertTo
import org.jetbrains.kotlinx.dataframe.api.convertToDouble
import org.jetbrains.kotlinx.dataframe.api.parse
import org.jetbrains.kotlinx.dataframe.api.parser
import org.jetbrains.kotlinx.dataframe.api.tryParse
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConversionException
import org.junit.Test
import java.math.BigDecimal
import java.util.Locale
import kotlin.reflect.typeOf

class ParserTests {
Expand Down Expand Up @@ -58,4 +66,88 @@ class ParserTests {
converted[0] shouldBe 1.0f
converted[1] shouldBe 0.321f
}

@Test
fun `convert to Boolean`() {
    // The delegated property name is significant: it must match the "col" name of the expected column.
    val col by columnOf(BigDecimal(1.0), BigDecimal(0.0), 0, 1, 10L, 0.0, 0.1)

    val actual = col.convertTo<Boolean>()
    val expected = DataColumn.createValueColumn(
        "col",
        listOf(true, false, false, true, true, false, true),
        typeOf<Boolean>()
    )

    // Zero values map to false, everything else to true.
    actual shouldBe expected
}

@Test
fun `converting String to Double in different locales`() {
// The default locale is global JVM state: remember it here and restore it in the finally block.
val currentLocale = Locale.getDefault()
try {
// Test 36 behaviour combinations:
// 3 source columns * 4 conversion variants * 3 system default locales.

// 3 source columns
val columnDot = columnOf("12.345", "67.890")
val columnComma = columnOf("12,345", "67,890")
val columnMixed = columnOf("12.345", "67,890")
// *
// 4 conversion variants: convertTo<Double>() without a locale argument,
// plus convertToDouble(locale) with each of the 3 locale arguments below
val parsingLocaleNotDefined: Locale? = null
val parsingLocaleUsesDot: Locale = Locale.forLanguageTag("en-US")
val parsingLocaleUsesComma: Locale = Locale.forLanguageTag("ru-RU")
// *
// 3 system locales

// System default locale #1: "C.UTF-8"
Locale.setDefault(Locale.forLanguageTag("C.UTF-8"))

// no locale argument: the comma is not read as a decimal separator ("12,345" -> 12345.0)
columnDot.convertTo<Double>().shouldBe(columnOf(12.345, 67.89))
columnComma.convertTo<Double>().shouldBe(columnOf(12345.0, 67890.0))
columnMixed.convertTo<Double>().shouldBe(columnOf(12.345, 67890.0))

// null locale argument: same expectations as without the argument
columnDot.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12.345, 67.89))
columnComma.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12345.0, 67890.0))
columnMixed.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12.345, 67890.0))

// explicit dot-decimal locale
columnDot.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12.345, 67.89))
columnComma.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12345.0, 67890.0))
columnMixed.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12.345, 67890.0))

// explicit comma-decimal locale: values containing a dot are rejected, no fallback is applied
shouldThrow<TypeConversionException> { columnDot.convertToDouble(parsingLocaleUsesComma) }
columnComma.convertToDouble(parsingLocaleUsesComma).shouldBe(columnOf(12.345, 67.89))
shouldThrow<TypeConversionException> { columnMixed.convertToDouble(parsingLocaleUsesComma) }

// System default locale #2: "en-US" (same expectations as for "C.UTF-8")
Locale.setDefault(Locale.forLanguageTag("en-US"))

columnDot.convertTo<Double>().shouldBe(columnOf(12.345, 67.89))
columnComma.convertTo<Double>().shouldBe(columnOf(12345.0, 67890.0))
columnMixed.convertTo<Double>().shouldBe(columnOf(12.345, 67890.0))

columnDot.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12.345, 67.89))
columnComma.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12345.0, 67890.0))
columnMixed.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12.345, 67890.0))

columnDot.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12.345, 67.89))
columnComma.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12345.0, 67890.0))
columnMixed.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12.345, 67890.0))

shouldThrow<TypeConversionException> { columnDot.convertToDouble(parsingLocaleUsesComma) }
columnComma.convertToDouble(parsingLocaleUsesComma).shouldBe(columnOf(12.345, 67.89))
shouldThrow<TypeConversionException> { columnMixed.convertToDouble(parsingLocaleUsesComma) }

// System default locale #3: "ru-RU"
Locale.setDefault(Locale.forLanguageTag("ru-RU"))

// no locale argument: the comma column is now read with a decimal comma ("12,345" -> 12.345),
// while the dot and mixed columns give the same results as under the other default locales
columnDot.convertTo<Double>().shouldBe(columnOf(12.345, 67.89))
columnComma.convertTo<Double>().shouldBe(columnOf(12.345, 67.89))
columnMixed.convertTo<Double>().shouldBe(columnOf(12.345, 67890.0))

// null locale argument: same expectations as without the argument
columnDot.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12.345, 67.89))
columnComma.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12.345, 67.89))
columnMixed.convertToDouble(parsingLocaleNotDefined).shouldBe(columnOf(12.345, 67890.0))

// an explicit locale argument gives the same results regardless of the system default
columnDot.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12.345, 67.89))
columnComma.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12345.0, 67890.0))
columnMixed.convertToDouble(parsingLocaleUsesDot).shouldBe(columnOf(12.345, 67890.0))

shouldThrow<TypeConversionException> { columnDot.convertToDouble(parsingLocaleUsesComma) }
columnComma.convertToDouble(parsingLocaleUsesComma).shouldBe(columnOf(12.345, 67.89))
shouldThrow<TypeConversionException> { columnMixed.convertToDouble(parsingLocaleUsesComma) }
} finally {
Locale.setDefault(currentLocale)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.jetbrains.jupyter.parser.notebook.Output
import org.junit.Ignore
import org.junit.Test
import java.io.File
import java.util.Locale

class SampleNotebooksTests : DataFrameJupyterTest() {
@Test
Expand Down Expand Up @@ -39,13 +40,23 @@ class SampleNotebooksTests : DataFrameJupyterTest() {
)

@Test
fun netflix() = exampleTest(
"netflix",
replacer = CodeReplacer.byMap(
testFile("netflix", "country_codes.csv"),
testFile("netflix", "netflix_titles.csv"),
)
)
fun netflix() {
    // The default locale is global JVM state: remember it and restore it afterwards.
    val savedLocale = Locale.getDefault()
    try {
        // The test data contains locale-dependent values (dates to be parsed), so pin the locale explicitly.
        Locale.setDefault(Locale.forLanguageTag("en-US"))

        val replacer = CodeReplacer.byMap(
            testFile("netflix", "country_codes.csv"),
            testFile("netflix", "netflix_titles.csv"),
        )
        exampleTest("netflix", replacer = replacer)
    } finally {
        Locale.setDefault(savedLocale)
    }
}

@Test
@Ignore
Expand Down
Loading