Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added MS SQL support for the dataframe-jdbc module #689

Merged
merged 10 commits into from
May 13, 2024
1 change: 1 addition & 0 deletions dataframe-jdbc/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies {
testImplementation(libs.postgresql)
testImplementation(libs.mysql)
testImplementation(libs.h2db)
testImplementation(libs.mssql)
testImplementation(libs.junit)
testImplementation(libs.sl4j)
testImplementation(libs.kotestAssertions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,14 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) {
* @return The corresponding Kotlin data type, or null if no mapping is found.
*/
public abstract fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType?

/**
* Constructs a SQL query with a limit clause.
*
* @param sqlQuery The original SQL query.
* @param limit The maximum number of rows to retrieve from the query. Default is 1.
* @return A new SQL query with the limit clause added.
*/
public open fun sqlQueryLimit(sqlQuery: String, limit: Int = 1): String =
"$sqlQuery LIMIT $limit"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package org.jetbrains.kotlinx.dataframe.io.db

import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata
import org.jetbrains.kotlinx.dataframe.io.TableMetadata
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import java.sql.ResultSet
import java.util.*
import kotlin.reflect.KType
import kotlin.reflect.full.createType

/**
* Represents the MSSQL database type.
*
* This class provides methods to convert data from a ResultSet to the appropriate type for MSSQL,
* and to generate the corresponding column schema.
*/
public object MsSql : DbType("sqlserver") {
override val driverClassName: String
get() = "com.microsoft.sqlserver.jdbc.SQLServerDriver"

override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? {
return null
}

override fun isSystemTable(tableMetadata: TableMetadata): Boolean {
val locale = Locale.getDefault()

fun String?.containsWithLowercase(substr: String) = this?.lowercase(locale)?.contains(substr) == true

val schemaName = tableMetadata.schemaName
val tableName = tableMetadata.name
val catalogName = tableMetadata.catalogue

return schemaName.containsWithLowercase("sys") ||
schemaName.containsWithLowercase("information_schema") ||
tableName.startsWith("sys") ||
tableName.startsWith("dt") ||
tableName.containsWithLowercase("sys_config") ||
catalogName.containsWithLowercase("system") ||
catalogName.containsWithLowercase("master") ||
catalogName.containsWithLowercase("model") ||
catalogName.containsWithLowercase("msdb") ||
catalogName.containsWithLowercase("tempdb")
zaleslaw marked this conversation as resolved.
Show resolved Hide resolved
}

override fun buildTableMetadata(tables: ResultSet): TableMetadata {
return TableMetadata(
tables.getString("table_name"),
tables.getString("table_schem"),
tables.getString("table_cat")
)
}

override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? {
return null
}

public override fun sqlQueryLimit(sqlQuery: String, limit: Int): String {
sqlQuery.replace("SELECT", "SELECT TOP $limit", ignoreCase = true)
return sqlQuery
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ public fun extractDBTypeFromUrl(url: String?): DbType {
MySql.dbTypeInJdbcUrl in url -> MySql
Sqlite.dbTypeInJdbcUrl in url -> Sqlite
PostgreSql.dbTypeInJdbcUrl in url -> PostgreSql
MsSql.dbTypeInJdbcUrl in url -> MsSql
else -> throw IllegalArgumentException(
"Unsupported database type in the url: $url. " +
"Only H2, MariaDB, MySQL, SQLite and PostgreSQL are supported!"
"Only H2, MariaDB, MySQL, MSSQL, SQLite and PostgreSQL are supported!"
)
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,17 +137,17 @@ public fun DataFrame.Companion.readSqlTable(
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): AnyFrame {
var preparedQuery = "SELECT * FROM $tableName"
if (limit > 0) preparedQuery += " LIMIT $limit"

val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)

val selectAllQuery = if (limit > 0) dbType.sqlQueryLimit("SELECT * FROM $tableName", limit)
else "SELECT * FROM $tableName"

connection.createStatement().use { st ->
logger.debug { "Connection with url:$url is established successfully." }

st.executeQuery(
preparedQuery
selectAllQuery
).use { rs ->
val tableColumns = getTableColumnsMetadata(rs)
return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit, inferNullability)
Expand Down Expand Up @@ -206,8 +206,7 @@ public fun DataFrame.Companion.readSqlQuery(
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)

var internalSqlQuery = sqlQuery
if (limit > 0) internalSqlQuery += " LIMIT $limit"
val internalSqlQuery = if (limit > 0) dbType.sqlQueryLimit(sqlQuery, limit) else sqlQuery

logger.debug { "Executing SQL query: $internalSqlQuery" }

Expand Down Expand Up @@ -317,9 +316,11 @@ public fun DataFrame.Companion.readAllSqlTables(
val table = dbType.buildTableMetadata(tables)
if (!dbType.isSystemTable(table)) {
// we filter her second time because of specific logic with SQLite and possible issues with future databases
zaleslaw marked this conversation as resolved.
Show resolved Hide resolved
// val tableName = if (table.catalogue != null) table.catalogue + "." + table.name else table.name
val tableName = if (catalogue != null) catalogue + "." + table.name else table.name

val tableName = when {
catalogue != null && table.schemaName != null -> "$catalogue.${table.schemaName}.${table.name}"
catalogue != null && table.schemaName == null -> "$catalogue.${table.name}"
else -> table.name
}
// TODO: both cases is schema specified or not in URL
// in h2 database name is recognized as a schema name https://www.h2database.com/html/features.html#database_url
// https://stackoverflow.com/questions/20896935/spring-hibernate-h2-database-schema-not-found
Expand Down Expand Up @@ -367,11 +368,12 @@ public fun DataFrame.Companion.getSchemaForSqlTable(
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)

val preparedQuery = "SELECT * FROM $tableName LIMIT 1"
val sqlQuery = "SELECT * FROM $tableName"
val selectFirstRowQuery = dbType.sqlQueryLimit(sqlQuery, limit = 1)

connection.createStatement().use { st ->
st.executeQuery(
preparedQuery
selectFirstRowQuery
).use { rs ->
val tableColumns = getTableColumnsMetadata(rs)
return buildSchemaByTableColumns(tableColumns, dbType)
Expand Down Expand Up @@ -532,15 +534,19 @@ private fun getTableColumnsMetadata(rs: ResultSet): MutableList<TableColumnMetad
val schema: String? = rs.statement.connection.schema.takeUnless { it.isNullOrBlank() }

for (i in 1 until numberOfColumns + 1) {
val tableName = metaData.getTableName(i)
val columnName = metaData.getColumnName(i)

// this algorithm works correctly only for SQL Table and ResultSet opened on one SQL table
val columnResultSet: ResultSet =
databaseMetaData.getColumns(catalog, schema, metaData.getTableName(i), metaData.getColumnName(i))
databaseMetaData.getColumns(catalog, schema, tableName, columnName)
val isNullable = if (columnResultSet.next()) {
columnResultSet.getString("IS_NULLABLE") == "YES"
} else {
true // we assume that it's nullable by default
}

val name = manageColumnNameDuplication(columnNameCounter, metaData.getColumnName(i))
val name = manageColumnNameDuplication(columnNameCounter, columnName)
val size = metaData.getColumnDisplaySize(i)
val type = metaData.getColumnTypeName(i)
val jdbcType = metaData.getColumnType(i)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ class JdbcTest {

val saleDataSchema = dataSchemas[1]
saleDataSchema.columns.size shouldBe 3
// TODO: fix nullability
saleDataSchema.columns["amount"]!!.type shouldBe typeOf<BigDecimal>()

val dbConfig = DatabaseConfiguration(url = URL)
Expand Down Expand Up @@ -675,6 +676,8 @@ class JdbcTest {
saleDataSchema1.columns["amount"]!!.type shouldBe typeOf<BigDecimal>()
}

// TODO: add the same test for each particular database and refactor the scenario to the common test case
// https://github.com/Kotlin/dataframe/issues/688
@Test
fun `infer nullability`() {
// prepare tables and data
Expand Down
Loading