Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a support for H2 modes #720

Merged
merged 10 commits into from
Jun 13, 2024
1 change: 1 addition & 0 deletions dataframe-jdbc/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies {
testImplementation(libs.mssql)
testImplementation(libs.junit)
testImplementation(libs.sl4j)
testImplementation(libs.jts)
testImplementation(libs.kotestAssertions) {
exclude("org.jetbrains.kotlin", "kotlin-stdlib-jdk8")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,67 @@ import kotlin.reflect.KType
/**
* Represents the H2 database type.
*
* This class provides methods to convert data from a ResultSet to the appropriate type for H2,
* This class provides methods to convert data from a ResultSet to the appropriate type for H2
* and to generate the corresponding column schema.
*
* NOTE: All date and timestamp related types are converted to String to avoid java.sql.* types.
* NOTE: All date and timestamp-related types are converted to String to avoid java.sql.* types.
*/
public object H2 : DbType("h2") {
public class H2(public val dialect: DbType = MySql) : DbType("h2") {
zaleslaw marked this conversation as resolved.
Show resolved Hide resolved
init {
require(dialect.javaClass.simpleName != "H2kt") { "H2 database could not be specified with H2 dialect!"}
zaleslaw marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* It contains constants related to different database modes.
*
* The mode value is used in the [extractDBTypeFromConnection] function to determine the corresponding `DbType` for the H2 database connection URL.
* For example, if the URL contains the mode value "MySQL", the H2 instance with the MySQL database type is returned.
* Otherwise, the `DbType` is determined based on the URL without the mode value.
*
* @see [extractDBTypeFromConnection]
* @see [createH2Instance]
*/
public companion object {
/** It represents the mode value "MySQL" for the H2 database. */
public const val MODE_MYSQL: String = "MySQL"

/** It represents the mode value "PostgreSQL" for the H2 database. */
public const val MODE_POSTGRESQL: String = "PostgreSQL"

/** It represents the mode value "MSSQLServer" for the H2 database. */
public const val MODE_MSSQLSERVER: String = "MSSQLServer"

/** It represents the mode value "MariaDB" for the H2 database. */
public const val MODE_MARIADB: String = "MariaDB"
}

override val driverClassName: String
get() = "org.h2.Driver"

override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? {
return null
return dialect.convertSqlTypeToColumnSchemaValue(tableColumnMetadata)
}

override fun isSystemTable(tableMetadata: TableMetadata): Boolean {
return tableMetadata.name.lowercase(Locale.getDefault()).contains("sys_") ||
tableMetadata.schemaName?.lowercase(Locale.getDefault())?.contains("information_schema") ?: false
val locale = Locale.getDefault()
fun String?.containsWithLowercase(substr: String) = this?.lowercase(locale)?.contains(substr) == true
val schemaName = tableMetadata.schemaName

// could be extended for other symptoms of the system tables for H2
val isH2SystemTable = schemaName.containsWithLowercase("information_schema")

return isH2SystemTable || dialect.isSystemTable(tableMetadata)
}

override fun buildTableMetadata(tables: ResultSet): TableMetadata {
return TableMetadata(
tables.getString("TABLE_NAME"),
tables.getString("TABLE_SCHEM"),
tables.getString("TABLE_CAT")
)
return dialect.buildTableMetadata(tables)
}

override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? {
return null
return dialect.convertSqlTypeToKType(tableColumnMetadata)
}

public override fun sqlQueryLimit(sqlQuery: String, limit: Int): String {
return dialect.sqlQueryLimit(sqlQuery, limit)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata
import org.jetbrains.kotlinx.dataframe.io.TableMetadata
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import java.sql.ResultSet
import java.util.*
import java.util.Locale
import kotlin.reflect.KType
import kotlin.reflect.full.createType

/**
* Represents the MSSQL database type.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,75 @@
package org.jetbrains.kotlinx.dataframe.io.db

import io.github.oshai.kotlinlogging.KotlinLogging
import java.sql.Connection
import java.sql.SQLException
import java.util.Locale

private val logger = KotlinLogging.logger {}

/**
* Extracts the database type from the given connection.
*
* @param [connection] the database connection.
* @return the corresponding [DbType].
* @throws [IllegalStateException] if URL information is missing in connection meta-data.
* @throws [IllegalArgumentException] if the URL specifies an unsupported database type.
* @throws [SQLException] if the URL is null.
*/
public fun extractDBTypeFromConnection(connection: Connection): DbType {
val url = connection.metaData?.url ?: throw IllegalStateException("URL information is missing in connection meta data!")
logger.info { "Processing DB type extraction for connection url: $url" }

return if (url.contains(H2().dbTypeInJdbcUrl)) {
// works only for H2 version 2
val modeQuery = "SELECT SETTING_VALUE FROM INFORMATION_SCHEMA.SETTINGS WHERE SETTING_NAME = 'MODE'"
var mode = ""
connection.createStatement().use { st ->
st.executeQuery(
modeQuery
).use { rs ->
if (rs.next()) {
mode = rs.getString("SETTING_VALUE")
logger.debug { "Fetched H2 DB mode: $mode" }
} else {
throw IllegalStateException("The information about H2 mode is not found in the H2 meta-data!")
}
}
}

// H2 doesn't support MariaDB and SQLite
when (mode.lowercase(Locale.getDefault())) {
H2.MODE_MYSQL.lowercase(Locale.getDefault()) -> H2(MySql)
H2.MODE_MSSQLSERVER.lowercase(Locale.getDefault()) -> H2(MsSql)
H2.MODE_POSTGRESQL.lowercase(Locale.getDefault()) -> H2(PostgreSql)
H2.MODE_MARIADB.lowercase(Locale.getDefault()) -> H2(MariaDb)
else -> {
val message = "Unsupported database type in the url: $url. " +
"Only MySQL, MariaDB, MSSQL and PostgreSQL are supported!"
logger.error { message }

throw IllegalArgumentException(message)
}
}
} else {
val dbType = extractDBTypeFromUrl(url)
logger.info { "Identified DB type as $dbType from url: $url" }
dbType
}
}

/**
* Extracts the database type from the given JDBC URL.
*
* @param [url] the JDBC URL.
* @return the corresponding [DbType].
* @throws RuntimeException if the url is null.
* @throws [RuntimeException] if the url is null.
*/
public fun extractDBTypeFromUrl(url: String?): DbType {
if (url != null) {
val helperH2Instance = H2()
zaleslaw marked this conversation as resolved.
Show resolved Hide resolved
return when {
H2.dbTypeInJdbcUrl in url -> H2
helperH2Instance.dbTypeInJdbcUrl in url -> createH2Instance(url)
MariaDb.dbTypeInJdbcUrl in url -> MariaDb
MySql.dbTypeInJdbcUrl in url -> MySql
Sqlite.dbTypeInJdbcUrl in url -> Sqlite
Expand All @@ -28,6 +85,37 @@ public fun extractDBTypeFromUrl(url: String?): DbType {
}
}

/**
* Creates an instance of DbType based on the provided JDBC URL.
*
* @param [url] The JDBC URL representing the database connection.
* @return The corresponding [DbType] instance.
* @throws [IllegalArgumentException] if the provided URL does not contain a valid mode.
*/
private fun createH2Instance(url: String): DbType {
val modePattern = "MODE=(.*?);".toRegex()
val matchResult = modePattern.find(url)

val mode: String = if (matchResult != null && matchResult.groupValues.size == 2) {
matchResult.groupValues[1]
} else {
throw IllegalArgumentException("The provided URL `$url` does not contain a valid mode.")
}

// H2 doesn't support MariaDB and SQLite
return when (mode.lowercase(Locale.getDefault())) {
H2.MODE_MYSQL.lowercase(Locale.getDefault()) -> H2(MySql)
H2.MODE_MSSQLSERVER.lowercase(Locale.getDefault()) -> H2(MsSql)
H2.MODE_POSTGRESQL.lowercase(Locale.getDefault()) -> H2(PostgreSql)
H2.MODE_MARIADB.lowercase(Locale.getDefault()) -> H2(MariaDb)

else -> throw IllegalArgumentException(
"Unsupported database mode: $mode. " +
"Only MySQL, MariaDB, MSSQL, PostgreSQL modes are supported!"
)
}
}

/**
* Retrieves the driver class name from the given JDBC URL.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.api.Infer
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.impl.schema.DataFrameSchemaImpl
zaleslaw marked this conversation as resolved.
Show resolved Hide resolved
import org.jetbrains.kotlinx.dataframe.io.db.DbType
import org.jetbrains.kotlinx.dataframe.io.db.extractDBTypeFromUrl
import org.jetbrains.kotlinx.dataframe.io.db.*
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema
import java.math.BigDecimal
Expand All @@ -26,7 +25,7 @@ import java.sql.SQLXML
import java.sql.Time
import java.sql.Timestamp
import java.sql.Types
import java.util.Date
import java.util.*
import kotlin.reflect.KType
import kotlin.reflect.full.createType
import kotlin.reflect.full.isSupertypeOf
Expand Down Expand Up @@ -138,7 +137,7 @@ public fun DataFrame.Companion.readSqlTable(
inferNullability: Boolean = true,
): AnyFrame {
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

val selectAllQuery = if (limit > 0) dbType.sqlQueryLimit("SELECT * FROM $tableName", limit)
else "SELECT * FROM $tableName"
Expand Down Expand Up @@ -203,8 +202,7 @@ public fun DataFrame.Companion.readSqlQuery(
"Also it should not contain any separators like `;`."
}

val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

val internalSqlQuery = if (limit > 0) dbType.sqlQueryLimit(sqlQuery, limit) else sqlQuery

Expand Down Expand Up @@ -283,8 +281,7 @@ public fun DataFrame.Companion.readResultSet(
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): AnyFrame {
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

return readResultSet(resultSet, dbType, limit, inferNullability)
}
Expand Down Expand Up @@ -329,8 +326,7 @@ public fun DataFrame.Companion.readAllSqlTables(
inferNullability: Boolean = true,
): Map<String, AnyFrame> {
val metaData = connection.metaData
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

// exclude a system and other tables without data, but it looks like it is supported badly for many databases
val tables = metaData.getTables(catalogue, null, null, arrayOf("TABLE"))
Expand Down Expand Up @@ -390,8 +386,7 @@ public fun DataFrame.Companion.getSchemaForSqlTable(
connection: Connection,
tableName: String
): DataFrameSchema {
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

val sqlQuery = "SELECT * FROM $tableName"
val selectFirstRowQuery = dbType.sqlQueryLimit(sqlQuery, limit = 1)
Expand Down Expand Up @@ -432,8 +427,7 @@ public fun DataFrame.Companion.getSchemaForSqlQuery(
* @see DriverManager.getConnection
*/
public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQuery: String): DataFrameSchema {
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

connection.createStatement().use { st ->
st.executeQuery(sqlQuery).use { rs ->
Expand Down Expand Up @@ -468,8 +462,7 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbTyp
* @return the schema of the [ResultSet] as a [DataFrameSchema] object.
*/
public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, connection: Connection): DataFrameSchema {
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

val tableColumns = getTableColumnsMetadata(resultSet)
return buildSchemaByTableColumns(tableColumns, dbType)
Expand All @@ -495,8 +488,7 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfig
*/
public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): Map<String, DataFrameSchema> {
val metaData = connection.metaData
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
val dbType = extractDBTypeFromConnection(connection)

val tableTypes = arrayOf("TABLE")
// exclude a system and other tables without data
Expand Down
Loading