Add a support for H2 modes (#720)
* Refactor H2 database util to inherit DbType

H2 database util has been refactored to inherit DbType, allowing for the use of different dialects. Updated the relevant test cases and added a function to correctly identify the dialect based on the URL. Added error handling for unsupported dialects.

* Split the tests for databases on two parts - local and H2 oriented

* Fixed tests

* Add jts-core library and improve code documentation

The jts-core library was added to the project dependencies, allowing for usage in the codebase. Moreover, some improvements in the documentation of the code were made. Specifically, better explanations were provided for error cases in the `extractDBTypeFromConnection` and `extractDBTypeFromUrl` functions, and extensive documentation was added for the companion object in the H2.kt file.

* Refactor code formatting across several DF-JDBC files

Various minor formatting changes have been applied to improve code readability. These include rearranging import statements in the 'postgresTest.kt' file, removing superfluous empty lines in 'mssqlTest.kt', and adjusting whitespace for consistency in 'util.kt' and 'H2.kt'.

* Update import statements across multiple files

* Refactor H2 class and extend tests

Refined the H2 class and added test coverage for specific conditions. The H2 class now compares class references instead of simple names, eliminating string comparison. Additionally, a test has been added to check that an exception is thrown when an H2 database is specified with the H2 dialect. Minor import adjustments were also made in the readJdbc and mssqlTest files.
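
The class-reference comparison described in the last item can be sketched in isolation. The types below (`Dialect`, `MySqlLike`, `Emulator`) are hypothetical stand-ins introduced for illustration, not the library's classes; only the shape of the `require` guard follows the commit.

```kotlin
// Hypothetical stand-ins for illustration; these are not the library's types.
open class Dialect(val name: String)

object MySqlLike : Dialect("mysql")

class Emulator(val dialect: Dialect = MySqlLike) : Dialect("h2") {
    init {
        // Comparing KClass references avoids string matching such as
        // dialect::class.simpleName != "Emulator", which breaks on renames.
        require(dialect::class != Emulator::class) {
            "Emulator cannot be configured with itself as the dialect"
        }
    }
}

fun main() {
    val ok = Emulator(MySqlLike)
    println(ok.dialect.name)

    // Nesting an Emulator inside an Emulator trips the guard.
    val failure = runCatching { Emulator(Emulator()) }
    println(failure.exceptionOrNull() is IllegalArgumentException)
}
```

Because `require` throws `IllegalArgumentException`, a test for this case only needs to assert that constructing the nested instance fails with that exception type.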
zaleslaw authored Jun 13, 2024
1 parent cf234e7 commit c40fb04
Showing 18 changed files with 1,689 additions and 58 deletions.
1 change: 1 addition & 0 deletions dataframe-jdbc/build.gradle.kts
@@ -28,6 +28,7 @@ dependencies {
     testImplementation(libs.mssql)
     testImplementation(libs.junit)
     testImplementation(libs.sl4j)
+    testImplementation(libs.jts)
     testImplementation(libs.kotestAssertions) {
         exclude("org.jetbrains.kotlin", "kotlin-stdlib-jdk8")
     }
@@ -10,33 +10,67 @@ import kotlin.reflect.KType
 /**
  * Represents the H2 database type.
  *
- * This class provides methods to convert data from a ResultSet to the appropriate type for H2,
+ * This class provides methods to convert data from a ResultSet to the appropriate type for H2
  * and to generate the corresponding column schema.
  *
- * NOTE: All date and timestamp related types are converted to String to avoid java.sql.* types.
+ * NOTE: All date and timestamp-related types are converted to String to avoid java.sql.* types.
  */
-public object H2 : DbType("h2") {
+public class H2(public val dialect: DbType = MySql) : DbType("h2") {
+    init {
+        require(dialect::class != H2::class) { "H2 database could not be specified with H2 dialect!" }
+    }
+
+    /**
+     * It contains constants related to different database modes.
+     *
+     * The mode value is used in the [extractDBTypeFromConnection] function to determine the corresponding `DbType` for the H2 database connection URL.
+     * For example, if the URL contains the mode value "MySQL", the H2 instance with the MySQL database type is returned.
+     * Otherwise, the `DbType` is determined based on the URL without the mode value.
+     *
+     * @see [extractDBTypeFromConnection]
+     * @see [createH2Instance]
+     */
+    public companion object {
+        /** It represents the mode value "MySQL" for the H2 database. */
+        public const val MODE_MYSQL: String = "MySQL"
+
+        /** It represents the mode value "PostgreSQL" for the H2 database. */
+        public const val MODE_POSTGRESQL: String = "PostgreSQL"
+
+        /** It represents the mode value "MSSQLServer" for the H2 database. */
+        public const val MODE_MSSQLSERVER: String = "MSSQLServer"
+
+        /** It represents the mode value "MariaDB" for the H2 database. */
+        public const val MODE_MARIADB: String = "MariaDB"
+    }
+
     override val driverClassName: String
         get() = "org.h2.Driver"

     override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? {
-        return null
+        return dialect.convertSqlTypeToColumnSchemaValue(tableColumnMetadata)
     }

     override fun isSystemTable(tableMetadata: TableMetadata): Boolean {
-        return tableMetadata.name.lowercase(Locale.getDefault()).contains("sys_") ||
-            tableMetadata.schemaName?.lowercase(Locale.getDefault())?.contains("information_schema") ?: false
+        val locale = Locale.getDefault()
+        fun String?.containsWithLowercase(substr: String) = this?.lowercase(locale)?.contains(substr) == true
+        val schemaName = tableMetadata.schemaName
+
+        // could be extended for other symptoms of the system tables for H2
+        val isH2SystemTable = schemaName.containsWithLowercase("information_schema")
+
+        return isH2SystemTable || dialect.isSystemTable(tableMetadata)
     }

     override fun buildTableMetadata(tables: ResultSet): TableMetadata {
-        return TableMetadata(
-            tables.getString("TABLE_NAME"),
-            tables.getString("TABLE_SCHEM"),
-            tables.getString("TABLE_CAT")
-        )
+        return dialect.buildTableMetadata(tables)
     }

     override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? {
-        return null
+        return dialect.convertSqlTypeToKType(tableColumnMetadata)
     }
+
+    public override fun sqlQueryLimit(sqlQuery: String, limit: Int): String {
+        return dialect.sqlQueryLimit(sqlQuery, limit)
+    }
 }
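
The delegation in the hunk above can be illustrated with a reduced model. This is a minimal sketch, assuming simplified hypothetical types (`TableInfo`, `SimpleDbType`, `PostgresLike`, `H2Like`) that only mimic the shape of `DbType`; the `pg_` prefix rule and the `LIMIT` rewrite are invented here for the example and are not taken from the library.

```kotlin
// Hypothetical, simplified model of the delegation; the real DbType has more members.
data class TableInfo(val name: String, val schemaName: String?)

abstract class SimpleDbType(val dbTypeInJdbcUrl: String) {
    abstract fun isSystemTable(table: TableInfo): Boolean
    abstract fun sqlQueryLimit(sqlQuery: String, limit: Int): String
}

// Assumed dialect rules, chosen only to make the delegation observable.
object PostgresLike : SimpleDbType("postgresql") {
    override fun isSystemTable(table: TableInfo) = table.name.startsWith("pg_")
    override fun sqlQueryLimit(sqlQuery: String, limit: Int) = "$sqlQuery LIMIT $limit"
}

class H2Like(val dialect: SimpleDbType) : SimpleDbType("h2") {
    override fun isSystemTable(table: TableInfo): Boolean {
        // H2's own marker first, then whatever the emulated dialect treats as system tables.
        val inInformationSchema =
            table.schemaName?.lowercase()?.contains("information_schema") == true
        return inInformationSchema || dialect.isSystemTable(table)
    }

    // Query rewriting is left entirely to the dialect.
    override fun sqlQueryLimit(sqlQuery: String, limit: Int) =
        dialect.sqlQueryLimit(sqlQuery, limit)
}

fun main() {
    val db = H2Like(PostgresLike)
    println(db.sqlQueryLimit("SELECT * FROM t", 5))
    println(db.isSystemTable(TableInfo("SETTINGS", "INFORMATION_SCHEMA")))
    println(db.isSystemTable(TableInfo("pg_class", "public")))
    println(db.isSystemTable(TableInfo("customers", "public")))
}
```

The wrapper keeps only the checks that are specific to H2 itself and forwards everything else, so supporting another mode would mostly mean passing a different dialect object.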
@@ -4,9 +4,8 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata
 import org.jetbrains.kotlinx.dataframe.io.TableMetadata
 import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
 import java.sql.ResultSet
-import java.util.*
+import java.util.Locale
 import kotlin.reflect.KType
-import kotlin.reflect.full.createType

 /**
  * Represents the MSSQL database type.
@@ -1,18 +1,75 @@
 package org.jetbrains.kotlinx.dataframe.io.db

+import io.github.oshai.kotlinlogging.KotlinLogging
+import java.sql.Connection
 import java.sql.SQLException
+import java.util.Locale

+private val logger = KotlinLogging.logger {}
+
+/**
+ * Extracts the database type from the given connection.
+ *
+ * @param [connection] the database connection.
+ * @return the corresponding [DbType].
+ * @throws [IllegalStateException] if URL information is missing in connection meta-data.
+ * @throws [IllegalArgumentException] if the URL specifies an unsupported database type.
+ * @throws [SQLException] if the URL is null.
+ */
+public fun extractDBTypeFromConnection(connection: Connection): DbType {
+    val url = connection.metaData?.url ?: throw IllegalStateException("URL information is missing in connection meta data!")
+    logger.info { "Processing DB type extraction for connection url: $url" }
+
+    return if (url.contains(H2().dbTypeInJdbcUrl)) {
+        // works only for H2 version 2
+        val modeQuery = "SELECT SETTING_VALUE FROM INFORMATION_SCHEMA.SETTINGS WHERE SETTING_NAME = 'MODE'"
+        var mode = ""
+        connection.createStatement().use { st ->
+            st.executeQuery(
+                modeQuery
+            ).use { rs ->
+                if (rs.next()) {
+                    mode = rs.getString("SETTING_VALUE")
+                    logger.debug { "Fetched H2 DB mode: $mode" }
+                } else {
+                    throw IllegalStateException("The information about H2 mode is not found in the H2 meta-data!")
+                }
+            }
+        }
+
+        // H2 doesn't support MariaDB and SQLite
+        when (mode.lowercase(Locale.getDefault())) {
+            H2.MODE_MYSQL.lowercase(Locale.getDefault()) -> H2(MySql)
+            H2.MODE_MSSQLSERVER.lowercase(Locale.getDefault()) -> H2(MsSql)
+            H2.MODE_POSTGRESQL.lowercase(Locale.getDefault()) -> H2(PostgreSql)
+            H2.MODE_MARIADB.lowercase(Locale.getDefault()) -> H2(MariaDb)
+            else -> {
+                val message = "Unsupported database type in the url: $url. " +
+                    "Only MySQL, MariaDB, MSSQL and PostgreSQL are supported!"
+                logger.error { message }
+
+                throw IllegalArgumentException(message)
+            }
+        }
+    } else {
+        val dbType = extractDBTypeFromUrl(url)
+        logger.info { "Identified DB type as $dbType from url: $url" }
+        dbType
+    }
+}
+
 /**
  * Extracts the database type from the given JDBC URL.
  *
  * @param [url] the JDBC URL.
  * @return the corresponding [DbType].
- * @throws RuntimeException if the url is null.
+ * @throws [RuntimeException] if the url is null.
  */
 public fun extractDBTypeFromUrl(url: String?): DbType {
     if (url != null) {
+        val helperH2Instance = H2()
         return when {
-            H2.dbTypeInJdbcUrl in url -> H2
+            helperH2Instance.dbTypeInJdbcUrl in url -> createH2Instance(url)
             MariaDb.dbTypeInJdbcUrl in url -> MariaDb
             MySql.dbTypeInJdbcUrl in url -> MySql
             Sqlite.dbTypeInJdbcUrl in url -> Sqlite
@@ -28,6 +85,37 @@ public fun extractDBTypeFromUrl(url: String?): DbType {
     }
 }

+/**
+ * Creates an instance of DbType based on the provided JDBC URL.
+ *
+ * @param [url] The JDBC URL representing the database connection.
+ * @return The corresponding [DbType] instance.
+ * @throws [IllegalArgumentException] if the provided URL does not contain a valid mode.
+ */
+private fun createH2Instance(url: String): DbType {
+    val modePattern = "MODE=(.*?);".toRegex()
+    val matchResult = modePattern.find(url)
+
+    val mode: String = if (matchResult != null && matchResult.groupValues.size == 2) {
+        matchResult.groupValues[1]
+    } else {
+        throw IllegalArgumentException("The provided URL `$url` does not contain a valid mode.")
+    }
+
+    // H2 doesn't support MariaDB and SQLite
+    return when (mode.lowercase(Locale.getDefault())) {
+        H2.MODE_MYSQL.lowercase(Locale.getDefault()) -> H2(MySql)
+        H2.MODE_MSSQLSERVER.lowercase(Locale.getDefault()) -> H2(MsSql)
+        H2.MODE_POSTGRESQL.lowercase(Locale.getDefault()) -> H2(PostgreSql)
+        H2.MODE_MARIADB.lowercase(Locale.getDefault()) -> H2(MariaDb)
+
+        else -> throw IllegalArgumentException(
+            "Unsupported database mode: $mode. " +
+                "Only MySQL, MariaDB, MSSQL, PostgreSQL modes are supported!"
+        )
+    }
+}
+
 /**
  * Retrieves the driver class name from the given JDBC URL.
  *
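
The `MODE=(.*?);` pattern used by `createH2Instance` only matches when the mode is followed by a semicolon, so a URL that ends with the `MODE` setting falls into the "does not contain a valid mode" branch. The sketch below reproduces that behaviour with plain `Regex` calls. The `relaxedPattern` variant and the `extractMode` helper are hypothetical additions for comparison and are not part of the commit.

```kotlin
// Pattern copied from the diff: the mode must be followed by a semicolon.
val strictPattern = "MODE=(.*?);".toRegex()

// Hypothetical relaxed variant: also accepts MODE as the last URL setting.
val relaxedPattern = "MODE=([^;]+)".toRegex()

// Hypothetical helper: returns the captured mode, or null when nothing matches.
fun extractMode(url: String, pattern: Regex): String? =
    pattern.find(url)?.groupValues?.get(1)

fun main() {
    val withSeparator = "jdbc:h2:mem:test;MODE=MySQL;DB_CLOSE_DELAY=-1"
    val modeLast = "jdbc:h2:mem:test;DB_CLOSE_DELAY=-1;MODE=PostgreSQL"

    println(extractMode(withSeparator, strictPattern)) // MySQL
    println(extractMode(modeLast, strictPattern))      // null
    println(extractMode(modeLast, relaxedPattern))     // PostgreSQL
}
```

Both patterns are case-sensitive with respect to the `MODE` key, so a URL written with `mode=` would not match either one.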
@@ -7,8 +7,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame
 import org.jetbrains.kotlinx.dataframe.api.Infer
 import org.jetbrains.kotlinx.dataframe.api.toDataFrame
 import org.jetbrains.kotlinx.dataframe.impl.schema.DataFrameSchemaImpl
-import org.jetbrains.kotlinx.dataframe.io.db.DbType
-import org.jetbrains.kotlinx.dataframe.io.db.extractDBTypeFromUrl
+import org.jetbrains.kotlinx.dataframe.io.db.*
 import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
 import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema
 import java.math.BigDecimal
@@ -138,7 +137,7 @@ public fun DataFrame.Companion.readSqlTable(
     inferNullability: Boolean = true,
 ): AnyFrame {
     val url = connection.metaData.url
-    val dbType = extractDBTypeFromUrl(url)
+    val dbType = extractDBTypeFromConnection(connection)

     val selectAllQuery = if (limit > 0) dbType.sqlQueryLimit("SELECT * FROM $tableName", limit)
     else "SELECT * FROM $tableName"
@@ -203,8 +202,7 @@ public fun DataFrame.Companion.readSqlQuery(
         "Also it should not contain any separators like `;`."
     }

-    val url = connection.metaData.url
-    val dbType = extractDBTypeFromUrl(url)
+    val dbType = extractDBTypeFromConnection(connection)

     val internalSqlQuery = if (limit > 0) dbType.sqlQueryLimit(sqlQuery, limit) else sqlQuery

@@ -283,8 +281,7 @@ public fun DataFrame.Companion.readResultSet(
     limit: Int = DEFAULT_LIMIT,
     inferNullability: Boolean = true,
 ): AnyFrame {
-    val url = connection.metaData.url
-    val dbType = extractDBTypeFromUrl(url)
+    val dbType = extractDBTypeFromConnection(connection)

     return readResultSet(resultSet, dbType, limit, inferNullability)
 }
@@ -329,8 +326,7 @@ public fun DataFrame.Companion.readAllSqlTables(
     inferNullability: Boolean = true,
 ): Map<String, AnyFrame> {
     val metaData = connection.metaData
-    val url = connection.metaData.url
-    val dbType = extractDBTypeFromUrl(url)
+    val dbType = extractDBTypeFromConnection(connection)

     // exclude a system and other tables without data, but it looks like it is supported badly for many databases
     val tables = metaData.getTables(catalogue, null, null, arrayOf("TABLE"))
@@ -390,8 +386,7 @@ public fun DataFrame.Companion.getSchemaForSqlTable(
     connection: Connection,
     tableName: String
 ): DataFrameSchema {
-    val url = connection.metaData.url
-    val dbType = extractDBTypeFromUrl(url)
+    val dbType = extractDBTypeFromConnection(connection)

     val sqlQuery = "SELECT * FROM $tableName"
     val selectFirstRowQuery = dbType.sqlQueryLimit(sqlQuery, limit = 1)
@@ -432,8 +427,7 @@ public fun DataFrame.Companion.getSchemaForSqlQuery(
  * @see DriverManager.getConnection
  */
 public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQuery: String): DataFrameSchema {
-    val url = connection.metaData.url
-    val dbType = extractDBTypeFromUrl(url)
+    val dbType = extractDBTypeFromConnection(connection)

     connection.createStatement().use { st ->
         st.executeQuery(sqlQuery).use { rs ->
@@ -468,8 +462,7 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbTyp
  * @return the schema of the [ResultSet] as a [DataFrameSchema] object.
  */
 public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, connection: Connection): DataFrameSchema {
-    val url = connection.metaData.url
-    val dbType = extractDBTypeFromUrl(url)
+    val dbType = extractDBTypeFromConnection(connection)

     val tableColumns = getTableColumnsMetadata(resultSet)
     return buildSchemaByTableColumns(tableColumns, dbType)
@@ -495,8 +488,7 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfig
  */
 public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): Map<String, DataFrameSchema> {
     val metaData = connection.metaData
-    val url = connection.metaData.url
-    val dbType = extractDBTypeFromUrl(url)
+    val dbType = extractDBTypeFromConnection(connection)

     val tableTypes = arrayOf("TABLE")
     // exclude a system and other tables without data