From 3e023bad31e9ea8de0200bbceafb6a9d64b88947 Mon Sep 17 00:00:00 2001 From: zaleslaw Date: Tue, 30 Apr 2024 13:11:03 +0200 Subject: [PATCH 01/10] Add support for MS SQL database type and associated tests Added support for MS SQL database type in the util.kt file and created a new file for MS SQL configuration. Additionally, implemented test cases for new support in mssqlTest.kt. --- dataframe-jdbc/build.gradle.kts | 2 + .../kotlinx/dataframe/io/db/MsSql.kt | 39 ++ .../jetbrains/kotlinx/dataframe/io/db/util.kt | 1 + .../kotlinx/dataframe/io/mssqlTest.kt | 416 ++++++++++++++++++ 4 files changed, 458 insertions(+) create mode 100644 dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt create mode 100644 dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt diff --git a/dataframe-jdbc/build.gradle.kts b/dataframe-jdbc/build.gradle.kts index 19446166c..aa7d4302a 100644 --- a/dataframe-jdbc/build.gradle.kts +++ b/dataframe-jdbc/build.gradle.kts @@ -25,6 +25,8 @@ dependencies { testImplementation(libs.postgresql) testImplementation(libs.mysql) testImplementation(libs.h2db) + // TODO + testImplementation ("com.microsoft.sqlserver:mssql-jdbc:12.6.1.jre11") testImplementation(libs.junit) testImplementation(libs.sl4j) testImplementation(libs.kotestAssertions) { diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt new file mode 100644 index 000000000..c7781e00e --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt @@ -0,0 +1,39 @@ +package org.jetbrains.kotlinx.dataframe.io.db + +import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata +import org.jetbrains.kotlinx.dataframe.io.TableMetadata +import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema +import java.sql.ResultSet +import kotlin.reflect.KType +import kotlin.reflect.full.createType + +/** + * Represents the MariaDb database type. + * + * This class provides methods to convert data from a ResultSet to the appropriate type for MariaDb, + * and to generate the corresponding column schema. + */ +public object MsSql : DbType("sqlserver") { + override val driverClassName: String + get() = "com.microsoft.sqlserver.jdbc.SQLServerDriver" + + override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? { + return null + } + + override fun isSystemTable(tableMetadata: TableMetadata): Boolean { + return MySql.isSystemTable(tableMetadata) + } + + override fun buildTableMetadata(tables: ResultSet): TableMetadata { + return TableMetadata( + tables.getString("table_name"), + tables.getString("table_schem"), + tables.getString("table_cat") + ) + } + + override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? 
{ + return null + } +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt index 1ea06bc1e..1c3e9c238 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt @@ -17,6 +17,7 @@ public fun extractDBTypeFromUrl(url: String?): DbType { MySql.dbTypeInJdbcUrl in url -> MySql Sqlite.dbTypeInJdbcUrl in url -> Sqlite PostgreSql.dbTypeInJdbcUrl in url -> PostgreSql + MsSql.dbTypeInJdbcUrl in url -> MsSql else -> throw IllegalArgumentException( "Unsupported database type in the url: $url. " + "Only H2, MariaDB, MySQL, SQLite and PostgreSQL are supported!" diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt new file mode 100644 index 000000000..906ab06c6 --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt @@ -0,0 +1,416 @@ +package org.jetbrains.kotlinx.dataframe.io + +import io.kotest.matchers.shouldBe +import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.annotations.DataSchema +import org.jetbrains.kotlinx.dataframe.api.add +import org.jetbrains.kotlinx.dataframe.api.cast +import org.jetbrains.kotlinx.dataframe.api.filter +import org.jetbrains.kotlinx.dataframe.api.select +import org.junit.AfterClass +import org.junit.BeforeClass +import org.junit.Ignore +import org.junit.Test +import java.math.BigDecimal +import java.sql.Connection +import java.sql.DriverManager +import java.sql.SQLException +import java.util.* +import kotlin.reflect.typeOf + +private const val URL = "jdbc:sqlserver://localhost:1433;encrypt=true;trustServerCertificate=true" +private const val USER_NAME = "root" +private const val PASSWORD = "pass" +private const val TEST_DATABASE_NAME = "testKDFdatabase" + +@DataSchema +interface Table1MSSSQL { + val id: Int + val bitCol: Boolean + val tinyintCol: Int + val smallintCol: Short? 
+ val mediumintCol: Int + val mediumintUnsignedCol: Int + val integerCol: Int + val intCol: Int + val integerUnsignedCol: Long + val bigintCol: Long + val floatCol: Float + val doubleCol: Double + val decimalCol: BigDecimal + val dateCol: String + val datetimeCol: String + val timestampCol: String + val timeCol: String + val yearCol: String + val varcharCol: String + val charCol: String + val binaryCol: ByteArray + val varbinaryCol: ByteArray + val tinyblobCol: ByteArray + val blobCol: ByteArray + val mediumblobCol: ByteArray + val longblobCol: ByteArray + val textCol: String + val mediumtextCol: String + val longtextCol: String + val enumCol: String + val setCol: Char + val jsonCol: String +} + +class MSSQLTest { + companion object { + private lateinit var connection: Connection + + @BeforeClass + @JvmStatic + fun setUpClass() { + connection = DriverManager.getConnection(URL, USER_NAME, PASSWORD) + + connection.createStatement().use { st -> + // Drop the test database if it exists + // val dropDatabaseQuery = "DROP DATABASE IF EXISTS $TEST_DATABASE_NAME" + // st.executeUpdate(dropDatabaseQuery) + + // Create the test database + // val createDatabaseQuery = "CREATE DATABASE $TEST_DATABASE_NAME" + // st.executeUpdate(createDatabaseQuery) + + // Use the newly created database + val useDatabaseQuery = "USE $TEST_DATABASE_NAME" + st.executeUpdate(useDatabaseQuery) + } + + // connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table1") } + // connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table2") } +// TODO: create if not exists is too complex https://forums.sqlteam.com/t/create-table-if-not-exists/22596/4 + @Language("SQL") + val createTableQuery = """ + CREATE TABLE Table1 ( + bigintColumn BIGINT, + binaryColumn BINARY(50), + bitColumn BIT, + charColumn CHAR(10), + dateColumn DATE, + datetime3Column DATETIME2(3), + datetime2Column DATETIME2, + datetimeoffset2Column DATETIMEOFFSET(2), + decimalColumn DECIMAL(10,2), + floatColumn FLOAT, + imageColumn IMAGE, + intColumn INT, + moneyColumn MONEY, + ncharColumn NCHAR(10), + ntextColumn NTEXT, + numericColumn NUMERIC(10,2), + nvarcharColumn NVARCHAR(50), + nvarcharMaxColumn NVARCHAR(MAX), + realColumn REAL, + smalldatetimeColumn SMALLDATETIME, + smallintColumn SMALLINT, + smallmoneyColumn SMALLMONEY, + textColumn TEXT, + timeColumn TIME, + timestampColumn DATETIME2, + tinyintColumn TINYINT, + uniqueidentifierColumn UNIQUEIDENTIFIER, + varbinaryColumn VARBINARY(50), + varbinaryMaxColumn VARBINARY(MAX), + varcharColumn VARCHAR(50), + varcharMaxColumn VARCHAR(MAX), + xmlColumn XML, + sqlvariantColumn SQL_VARIANT, + geometryColumn GEOMETRY, + geographyColumn GEOGRAPHY +); + """ + + // TODO: timestamp column could be removed +/* connection.createStatement().execute( + createTableQuery.trimIndent() + )*/ +/* + @Language("SQL") + val createTableQuery2 = """ + CREATE TABLE IF NOT EXISTS table2 ( + id INT AUTO_INCREMENT PRIMARY KEY, + bitCol BIT, + tinyintCol TINYINT, + smallintCol SMALLINT, + mediumintCol MEDIUMINT, + mediumintUnsignedCol MEDIUMINT UNSIGNED, + integerCol INTEGER, + intCol INT, + integerUnsignedCol INTEGER UNSIGNED, + bigintCol BIGINT, + floatCol FLOAT, + doubleCol DOUBLE, + decimalCol DECIMAL, + dateCol DATE, + datetimeCol DATETIME, + timestampCol TIMESTAMP, + timeCol TIME, + yearCol YEAR, + varcharCol VARCHAR(255), + charCol CHAR(10), + binaryCol BINARY(64), + varbinaryCol VARBINARY(128), + tinyblobCol TINYBLOB, + blobCol BLOB, + mediumblobCol MEDIUMBLOB, + longblobCol LONGBLOB, + textCol 
TEXT, + mediumtextCol MEDIUMTEXT, + longtextCol LONGTEXT, + enumCol ENUM('Value1', 'Value2', 'Value3'), + setCol SET('Option1', 'Option2', 'Option3') + ) + """ + connection.createStatement().execute( + createTableQuery2.trimIndent() + ) +*/ + @Language("SQL") + val insertData1 = """ + INSERT INTO Table1 ( + bigintColumn, binaryColumn, bitColumn, charColumn, dateColumn, datetime3Column, datetime2Column, + datetimeoffset2Column, decimalColumn, floatColumn, imageColumn, intColumn, moneyColumn, ncharColumn, + ntextColumn, numericColumn, nvarcharColumn, nvarcharMaxColumn, realColumn, smalldatetimeColumn, + smallintColumn, smallmoneyColumn, textColumn, timeColumn, timestampColumn, tinyintColumn, + uniqueidentifierColumn, varbinaryColumn, varbinaryMaxColumn, varcharColumn, varcharMaxColumn, + xmlColumn, sqlvariantColumn, geometryColumn, geographyColumn + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +""".trimIndent() + + + /* @Language("SQL") + val insertData2 = """ + INSERT INTO table2 ( + bitCol, tinyintCol, smallintCol, mediumintCol, mediumintUnsignedCol, integerCol, intCol, + integerUnsignedCol, bigintCol, floatCol, doubleCol, decimalCol, dateCol, datetimeCol, timestampCol, + timeCol, yearCol, varcharCol, charCol, binaryCol, varbinaryCol, tinyblobCol, blobCol, + mediumblobCol, longblobCol, textCol, mediumtextCol, longtextCol, enumCol, setCol + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """.trimIndent()*/ + + connection.prepareStatement(insertData1).use { st -> + for (i in 1..5) { + st.setLong(1, 123456789012345L) // bigintColumn + st.setBytes(2, byteArrayOf(0x01, 0x23, 0x45, 0x67, 0x67, 0x67, 0x67, 0x67)) // binaryColumn + st.setBoolean(3, true) // bitColumn + st.setString(4, "Sample") // charColumn + st.setDate(5, java.sql.Date(System.currentTimeMillis())) // dateColumn + st.setTimestamp(6, java.sql.Timestamp(System.currentTimeMillis())) // datetime3Column + st.setTimestamp(7, java.sql.Timestamp(System.currentTimeMillis())) // datetime2Column + st.setTimestamp(8, java.sql.Timestamp(System.currentTimeMillis())) // datetimeoffset2Column + st.setBigDecimal(9, BigDecimal("12345.67")) // decimalColumn + st.setFloat(10, 123.45f) // floatColumn + st.setNull(11, java.sql.Types.NULL) // imageColumn (assuming nullable) + st.setInt(12, 123456) // intColumn + st.setBigDecimal(13, BigDecimal("123.45")) // moneyColumn + st.setString(14, "Sample") // ncharColumn + st.setString(15, "Sample text") // ntextColumn + st.setBigDecimal(16, BigDecimal("1234.56")) // numericColumn + st.setString(17, "Sample") // nvarcharColumn + st.setString(18, "Sample text") // nvarcharMaxColumn + st.setFloat(19, 123.45f) // realColumn + st.setTimestamp(20, java.sql.Timestamp(System.currentTimeMillis())) // smalldatetimeColumn + st.setInt(21, 123) // smallintColumn + st.setBigDecimal(22, BigDecimal("123.45")) // smallmoneyColumn + st.setString(23, "Sample text") // textColumn + st.setTime(24, java.sql.Time(System.currentTimeMillis())) // timeColumn + st.setTimestamp(25, java.sql.Timestamp(System.currentTimeMillis())) // timestampColumn + st.setInt(26, 123) // tinyintColumn + //st.setObject(27, null) // udtColumn (assuming nullable) + st.setObject(27, UUID.randomUUID()) // uniqueidentifierColumn + st.setBytes(28, byteArrayOf(0x01, 0x23, 0x45, 0x67, 0x67, 0x67, 0x67, 0x67)) // varbinaryColumn + st.setBytes(29, byteArrayOf(0x01, 0x23, 0x45, 0x67, 0x67, 0x67, 0x67, 0x67)) // varbinaryMaxColumn + 
st.setString(30, "Sample") // varcharColumn + st.setString(31, "Sample text") // varcharMaxColumn + st.setString(32, "Sample") // xmlColumn + st.setString(33, "SQL_VARIANT") // sqlvariantColumn + st.setBytes(34, + byteArrayOf(0xE6.toByte(), 0x10, 0x00, 0x00, 0x01, 0x0C, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x44, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x05, 0x4C, 0x0)) // geometryColumn + st.setString(35, "POINT(1 1)") // geographyColumn + // st.executeUpdate() + } + } + + /* connection.prepareStatement(insertData2).use { st -> + // Insert data into table2 + for (i in 1..3) { + st.setBoolean(1, false) + st.setByte(2, (i * 2).toByte()) + st.setShort(3, (i * 20).toShort()) + st.setInt(4, i * 200) + st.setInt(5, i * 200) + st.setInt(6, i * 200) + st.setInt(7, i * 200) + st.setInt(8, i * 200) + st.setInt(9, i * 200) + st.setFloat(10, i * 20.0f) + st.setDouble(11, i * 20.0) + st.setBigDecimal(12, BigDecimal(i * 20)) + st.setDate(13, java.sql.Date(System.currentTimeMillis())) + st.setTimestamp(14, java.sql.Timestamp(System.currentTimeMillis())) + st.setTimestamp(15, java.sql.Timestamp(System.currentTimeMillis())) + st.setTime(16, java.sql.Time(System.currentTimeMillis())) + st.setInt(17, 2023) + st.setString(18, "varcharValue$i") + st.setString(19, "charValue$i") + st.setBytes(20, "binaryValue".toByteArray()) + st.setBytes(21, "varbinaryValue".toByteArray()) + st.setBytes(22, "tinyblobValue".toByteArray()) + st.setBytes(23, "blobValue".toByteArray()) + st.setBytes(24, "mediumblobValue".toByteArray()) + st.setBytes(25, "longblobValue".toByteArray()) + st.setString(26, null) + st.setString(27, null) + st.setString(28, "longtextValue$i") + st.setString(29, "Value$i") + st.setString(30, "Option$i") + st.executeUpdate() + } + }*/ + } + + @AfterClass + @JvmStatic + fun tearDownClass() { + /*try { + connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table1") } + connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table2") } + connection.createStatement().use { st -> st.execute("DROP DATABASE IF EXISTS $TEST_DATABASE_NAME") } + connection.close() + } catch (e: SQLException) { + e.printStackTrace() + }*/ + } + } + + @Test + fun `basic test for reading sql tables`() { + val df1 = DataFrame.readSqlTable(connection, "table1").cast() + val result = df1.filter { it[Table1MSSSQL::id] == 1 } + result[0][26] shouldBe "textValue1" + + /*val schema = DataFrame.getSchemaForSqlTable(connection, "table1") + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["textCol"]!!.type shouldBe typeOf() + + val df2 = DataFrame.readSqlTable(connection, "table2").cast() + val result2 = df2.filter { it[Table2MSSQL::id] == 1 } + result2[0][26] shouldBe null + + val schema2 = DataFrame.getSchemaForSqlTable(connection, "table2") + schema2.columns["id"]!!.type shouldBe typeOf() + schema2.columns["textCol"]!!.type shouldBe typeOf()*/ + } + + @Test + fun `read from sql query`() { + /* @Language("SQL") + val sqlQuery = """ + SELECT + t1.id, + t1.enumCol, + t2.setCol + FROM table1 t1 + JOIN table2 t2 ON t1.id = t2.id + """.trimIndent() + + val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery).cast() + val result = df.filter { it[Table3MSSQL::id] == 1 } + result[0][2] shouldBe "Option1" + + val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery = sqlQuery) + schema.columns["id"]!!.type shouldBe typeOf() + schema.columns["enumCol"]!!.type shouldBe typeOf() + schema.columns["setCol"]!!.type shouldBe typeOf()*/ + } + + @Test + 
fun `read from all tables`() { + /* val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 1000) + + val table1Df = dataframes[0].cast() + + table1Df.rowsCount() shouldBe 3 + table1Df.filter { it[Table1MSSSQL::integerCol] > 100 }.rowsCount() shouldBe 2 + table1Df[0][11] shouldBe 10.0 + table1Df[0][26] shouldBe "textValue1" + table1Df[0][31] shouldBe JSON_STRING // TODO: https://github.com/Kotlin/dataframe/issues/462 + + val table2Df = dataframes[1].cast() + + table2Df.rowsCount() shouldBe 3 + table2Df.filter { it[Table2MSSQL::integerCol] != null && it[Table2MSSQL::integerCol]!! > 400 } + .rowsCount() shouldBe 1 + table2Df[0][11] shouldBe 20.0 + table2Df[0][26] shouldBe null*/ + } + + @Test + fun `reading numeric types`() { + /* val df1 = DataFrame.readSqlTable(connection, "table1").cast() + + val result = df1.select("tinyintCol") + .add("tinyintCol2") { it[Table1MSSSQL::tinyintCol] } + + result[0][1] shouldBe 1 + + val result1 = df1.select("smallintCol") + .add("smallintCol2") { it[Table1MSSSQL::smallintCol] } + + result1[0][1] shouldBe 10 + + val result2 = df1.select("mediumintCol") + .add("mediumintCol2") { it[Table1MSSSQL::mediumintCol] } + + result2[0][1] shouldBe 100 + + val result3 = df1.select("mediumintUnsignedCol") + .add("mediumintUnsignedCol2") { it[Table1MSSSQL::mediumintUnsignedCol] } + + result3[0][1] shouldBe 100 + + val result4 = df1.select("integerUnsignedCol") + .add("integerUnsignedCol2") { it[Table1MSSSQL::integerUnsignedCol] } + + result4[0][1] shouldBe 100L + + val result5 = df1.select("bigintCol") + .add("bigintCol2") { it[Table1MSSSQL::bigintCol] } + + result5[0][1] shouldBe 100 + + val result6 = df1.select("floatCol") + .add("floatCol2") { it[Table1MSSSQL::floatCol] } + + result6[0][1] shouldBe 10.0f + + val result7 = df1.select("doubleCol") + .add("doubleCol2") { it[Table1MSSSQL::doubleCol] } + + result7[0][1] shouldBe 10.0 + + val result8 = df1.select("decimalCol") + .add("decimalCol2") { it[Table1MSSSQL::decimalCol] } + + result8[0][1] shouldBe BigDecimal("10") + + val schema = DataFrame.getSchemaForSqlTable(connection, "table1") + + schema.columns["tinyintCol"]!!.type shouldBe typeOf() + schema.columns["smallintCol"]!!.type shouldBe typeOf() + schema.columns["mediumintCol"]!!.type shouldBe typeOf() + schema.columns["mediumintUnsignedCol"]!!.type shouldBe typeOf() + schema.columns["integerUnsignedCol"]!!.type shouldBe typeOf() + schema.columns["bigintCol"]!!.type shouldBe typeOf() + schema.columns["floatCol"]!!.type shouldBe typeOf() + schema.columns["doubleCol"]!!.type shouldBe typeOf() + schema.columns["decimalCol"]!!.type shouldBe typeOf()*/ + } +} From 849725f48d5954b5df4cd6cda704da4114afa02f Mon Sep 17 00:00:00 2001 From: zaleslaw Date: Tue, 30 Apr 2024 16:56:40 +0200 Subject: [PATCH 02/10] Add sqlQueryLimitOne method to DbType and update test cases Implemented sqlQueryLimitOne method in DbType companion object. This method generates a SQL query that selects one record from a given table. Also, updated the unit tests to accommodate these modifications. 
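As a rough illustration of the intended behaviour of the new hook (the table name below is only an example, not part of the change):

    // default implementation inherited from DbType
    H2.sqlQueryLimitOne("Customer")     // "SELECT * FROM Customer LIMIT 1"

    // MSSQL override added in this commit
    MsSql.sqlQueryLimitOne("Customer")  // "SELECT TOP 1 * FROM Customer"
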
--- .../kotlinx/dataframe/io/db/DbType.kt | 2 + .../kotlinx/dataframe/io/db/MsSql.kt | 2 + .../kotlinx/dataframe/io/readJdbc.kt | 4 +- .../kotlinx/dataframe/io/mssqlTest.kt | 457 +++++++----------- 4 files changed, 168 insertions(+), 297 deletions(-) diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt index d025a34b8..ee86d5514 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt @@ -50,4 +50,6 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) { * @return The corresponding Kotlin data type, or null if no mapping is found. */ public abstract fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? + + public open fun sqlQueryLimitOne(tableName: String): String = "SELECT * FROM $tableName LIMIT 1" } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt index c7781e00e..2ba9c847e 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt @@ -36,4 +36,6 @@ public object MsSql : DbType("sqlserver") { override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? { return null } + + public override fun sqlQueryLimitOne(tableName: String): String = "SELECT TOP 1 * FROM $tableName" } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt index 2b6d0e1b6..a36b43e14 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt @@ -367,11 +367,11 @@ public fun DataFrame.Companion.getSchemaForSqlTable( val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) - val preparedQuery = "SELECT * FROM $tableName LIMIT 1" + val selectFirstRowQuery = dbType.sqlQueryLimitOne(tableName) connection.createStatement().use { st -> st.executeQuery( - preparedQuery + selectFirstRowQuery ).use { rs -> val tableColumns = getTableColumnsMetadata(rs) return buildSchemaByTableColumns(tableColumns, dbType) diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt index 906ab06c6..3370070ca 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt @@ -4,10 +4,7 @@ import io.kotest.matchers.shouldBe import org.intellij.lang.annotations.Language import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.annotations.DataSchema -import org.jetbrains.kotlinx.dataframe.api.add -import org.jetbrains.kotlinx.dataframe.api.cast -import org.jetbrains.kotlinx.dataframe.api.filter -import org.jetbrains.kotlinx.dataframe.api.select +import org.jetbrains.kotlinx.dataframe.api.* import org.junit.AfterClass import org.junit.BeforeClass import org.junit.Ignore @@ -27,37 +24,40 @@ private const val TEST_DATABASE_NAME = "testKDFdatabase" @DataSchema interface Table1MSSSQL { val id: Int - val bitCol: Boolean - val 
tinyintCol: Int - val smallintCol: Short? - val mediumintCol: Int - val mediumintUnsignedCol: Int - val integerCol: Int - val intCol: Int - val integerUnsignedCol: Long - val bigintCol: Long - val floatCol: Float - val doubleCol: Double - val decimalCol: BigDecimal - val dateCol: String - val datetimeCol: String - val timestampCol: String - val timeCol: String - val yearCol: String - val varcharCol: String - val charCol: String - val binaryCol: ByteArray - val varbinaryCol: ByteArray - val tinyblobCol: ByteArray - val blobCol: ByteArray - val mediumblobCol: ByteArray - val longblobCol: ByteArray - val textCol: String - val mediumtextCol: String - val longtextCol: String - val enumCol: String - val setCol: Char - val jsonCol: String + val bigintColumn: Long + val binaryColumn: ByteArray + val bitColumn: Boolean + val charColumn: Char + val dateColumn: Date + val datetime3Column: java.sql.Timestamp + val datetime2Column: java.sql.Timestamp + val datetimeoffset2Column: String + val decimalColumn: BigDecimal + val floatColumn: Double + val imageColumn: ByteArray? + val intColumn: Int + val moneyColumn: BigDecimal + val ncharColumn: Char + val ntextColumn: String + val numericColumn: BigDecimal + val nvarcharColumn: String + val nvarcharMaxColumn: String + val realColumn: Float + val smalldatetimeColumn: java.sql.Timestamp + val smallintColumn: Int + val smallmoneyColumn: BigDecimal + val timeColumn: java.sql.Time + val timestampColumn: java.sql.Timestamp + val tinyintColumn: Int + val uniqueidentifierColumn: Char + val varbinaryColumn: ByteArray + val varbinaryMaxColumn: ByteArray + val varcharColumn: String + val varcharMaxColumn: String + val xmlColumn: String + val sqlvariantColumn: String + val geometryColumn: String + val geographyColumn: String } class MSSQLTest { @@ -71,129 +71,76 @@ class MSSQLTest { connection.createStatement().use { st -> // Drop the test database if it exists - // val dropDatabaseQuery = "DROP DATABASE IF EXISTS $TEST_DATABASE_NAME" - // st.executeUpdate(dropDatabaseQuery) + val dropDatabaseQuery = "IF DB_ID('$TEST_DATABASE_NAME') IS NOT NULL\n" + + "DROP DATABASE $TEST_DATABASE_NAME" + st.executeUpdate(dropDatabaseQuery) // Create the test database - // val createDatabaseQuery = "CREATE DATABASE $TEST_DATABASE_NAME" - // st.executeUpdate(createDatabaseQuery) + val createDatabaseQuery = "CREATE DATABASE $TEST_DATABASE_NAME" + st.executeUpdate(createDatabaseQuery) // Use the newly created database val useDatabaseQuery = "USE $TEST_DATABASE_NAME" st.executeUpdate(useDatabaseQuery) } - // connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table1") } - // connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table2") } -// TODO: create if not exists is too complex https://forums.sqlteam.com/t/create-table-if-not-exists/22596/4 @Language("SQL") val createTableQuery = """ - CREATE TABLE Table1 ( - bigintColumn BIGINT, - binaryColumn BINARY(50), - bitColumn BIT, - charColumn CHAR(10), - dateColumn DATE, - datetime3Column DATETIME2(3), - datetime2Column DATETIME2, - datetimeoffset2Column DATETIMEOFFSET(2), - decimalColumn DECIMAL(10,2), - floatColumn FLOAT, - imageColumn IMAGE, - intColumn INT, - moneyColumn MONEY, - ncharColumn NCHAR(10), - ntextColumn NTEXT, - numericColumn NUMERIC(10,2), - nvarcharColumn NVARCHAR(50), - nvarcharMaxColumn NVARCHAR(MAX), - realColumn REAL, - smalldatetimeColumn SMALLDATETIME, - smallintColumn SMALLINT, - smallmoneyColumn SMALLMONEY, - textColumn TEXT, - timeColumn TIME, - timestampColumn 
DATETIME2, - tinyintColumn TINYINT, - uniqueidentifierColumn UNIQUEIDENTIFIER, - varbinaryColumn VARBINARY(50), - varbinaryMaxColumn VARBINARY(MAX), - varcharColumn VARCHAR(50), - varcharMaxColumn VARCHAR(MAX), - xmlColumn XML, - sqlvariantColumn SQL_VARIANT, - geometryColumn GEOMETRY, - geographyColumn GEOGRAPHY -); - """ - - // TODO: timestamp column could be removed -/* connection.createStatement().execute( - createTableQuery.trimIndent() - )*/ -/* - @Language("SQL") - val createTableQuery2 = """ - CREATE TABLE IF NOT EXISTS table2 ( - id INT AUTO_INCREMENT PRIMARY KEY, - bitCol BIT, - tinyintCol TINYINT, - smallintCol SMALLINT, - mediumintCol MEDIUMINT, - mediumintUnsignedCol MEDIUMINT UNSIGNED, - integerCol INTEGER, - intCol INT, - integerUnsignedCol INTEGER UNSIGNED, - bigintCol BIGINT, - floatCol FLOAT, - doubleCol DOUBLE, - decimalCol DECIMAL, - dateCol DATE, - datetimeCol DATETIME, - timestampCol TIMESTAMP, - timeCol TIME, - yearCol YEAR, - varcharCol VARCHAR(255), - charCol CHAR(10), - binaryCol BINARY(64), - varbinaryCol VARBINARY(128), - tinyblobCol TINYBLOB, - blobCol BLOB, - mediumblobCol MEDIUMBLOB, - longblobCol LONGBLOB, - textCol TEXT, - mediumtextCol MEDIUMTEXT, - longtextCol LONGTEXT, - enumCol ENUM('Value1', 'Value2', 'Value3'), - setCol SET('Option1', 'Option2', 'Option3') - ) + CREATE TABLE Table1 ( + id INT NOT NULL IDENTITY PRIMARY KEY, + bigintColumn BIGINT, + binaryColumn BINARY(50), + bitColumn BIT, + charColumn CHAR(10), + dateColumn DATE, + datetime3Column DATETIME2(3), + datetime2Column DATETIME2, + datetimeoffset2Column DATETIMEOFFSET(2), + decimalColumn DECIMAL(10,2), + floatColumn FLOAT, + imageColumn IMAGE, + intColumn INT, + moneyColumn MONEY, + ncharColumn NCHAR(10), + ntextColumn NTEXT, + numericColumn NUMERIC(10,2), + nvarcharColumn NVARCHAR(50), + nvarcharMaxColumn NVARCHAR(MAX), + realColumn REAL, + smalldatetimeColumn SMALLDATETIME, + smallintColumn SMALLINT, + smallmoneyColumn SMALLMONEY, + textColumn TEXT, + timeColumn TIME, + timestampColumn DATETIME2, + tinyintColumn TINYINT, + uniqueidentifierColumn UNIQUEIDENTIFIER, + varbinaryColumn VARBINARY(50), + varbinaryMaxColumn VARBINARY(MAX), + varcharColumn VARCHAR(50), + varcharMaxColumn VARCHAR(MAX), + xmlColumn XML, + sqlvariantColumn SQL_VARIANT, + geometryColumn GEOMETRY, + geographyColumn GEOGRAPHY + ); """ + connection.createStatement().execute( - createTableQuery2.trimIndent() + createTableQuery.trimIndent() ) -*/ + @Language("SQL") val insertData1 = """ - INSERT INTO Table1 ( - bigintColumn, binaryColumn, bitColumn, charColumn, dateColumn, datetime3Column, datetime2Column, - datetimeoffset2Column, decimalColumn, floatColumn, imageColumn, intColumn, moneyColumn, ncharColumn, - ntextColumn, numericColumn, nvarcharColumn, nvarcharMaxColumn, realColumn, smalldatetimeColumn, - smallintColumn, smallmoneyColumn, textColumn, timeColumn, timestampColumn, tinyintColumn, - uniqueidentifierColumn, varbinaryColumn, varbinaryMaxColumn, varcharColumn, varcharMaxColumn, - xmlColumn, sqlvariantColumn, geometryColumn, geographyColumn - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
-""".trimIndent() - - - /* @Language("SQL") - val insertData2 = """ - INSERT INTO table2 ( - bitCol, tinyintCol, smallintCol, mediumintCol, mediumintUnsignedCol, integerCol, intCol, - integerUnsignedCol, bigintCol, floatCol, doubleCol, decimalCol, dateCol, datetimeCol, timestampCol, - timeCol, yearCol, varcharCol, charCol, binaryCol, varbinaryCol, tinyblobCol, blobCol, - mediumblobCol, longblobCol, textCol, mediumtextCol, longtextCol, enumCol, setCol - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """.trimIndent()*/ + INSERT INTO Table1 ( + bigintColumn, binaryColumn, bitColumn, charColumn, dateColumn, datetime3Column, datetime2Column, + datetimeoffset2Column, decimalColumn, floatColumn, imageColumn, intColumn, moneyColumn, ncharColumn, + ntextColumn, numericColumn, nvarcharColumn, nvarcharMaxColumn, realColumn, smalldatetimeColumn, + smallintColumn, smallmoneyColumn, textColumn, timeColumn, timestampColumn, tinyintColumn, + uniqueidentifierColumn, varbinaryColumn, varbinaryMaxColumn, varcharColumn, varcharMaxColumn, + xmlColumn, sqlvariantColumn, geometryColumn, geographyColumn + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """.trimIndent() connection.prepareStatement(insertData1).use { st -> for (i in 1..5) { @@ -211,15 +158,15 @@ class MSSQLTest { st.setInt(12, 123456) // intColumn st.setBigDecimal(13, BigDecimal("123.45")) // moneyColumn st.setString(14, "Sample") // ncharColumn - st.setString(15, "Sample text") // ntextColumn + st.setString(15, "Sample$i text") // ntextColumn st.setBigDecimal(16, BigDecimal("1234.56")) // numericColumn st.setString(17, "Sample") // nvarcharColumn - st.setString(18, "Sample text") // nvarcharMaxColumn + st.setString(18, "Sample$i text") // nvarcharMaxColumn st.setFloat(19, 123.45f) // realColumn st.setTimestamp(20, java.sql.Timestamp(System.currentTimeMillis())) // smalldatetimeColumn st.setInt(21, 123) // smallintColumn st.setBigDecimal(22, BigDecimal("123.45")) // smallmoneyColumn - st.setString(23, "Sample text") // textColumn + st.setString(23, "Sample$i text") // textColumn st.setTime(24, java.sql.Time(System.currentTimeMillis())) // timeColumn st.setTimestamp(25, java.sql.Timestamp(System.currentTimeMillis())) // timestampColumn st.setInt(26, 123) // tinyintColumn @@ -227,190 +174,110 @@ class MSSQLTest { st.setObject(27, UUID.randomUUID()) // uniqueidentifierColumn st.setBytes(28, byteArrayOf(0x01, 0x23, 0x45, 0x67, 0x67, 0x67, 0x67, 0x67)) // varbinaryColumn st.setBytes(29, byteArrayOf(0x01, 0x23, 0x45, 0x67, 0x67, 0x67, 0x67, 0x67)) // varbinaryMaxColumn - st.setString(30, "Sample") // varcharColumn - st.setString(31, "Sample text") // varcharMaxColumn - st.setString(32, "Sample") // xmlColumn + st.setString(30, "Sample$i") // varcharColumn + st.setString(31, "Sample$i text") // varcharMaxColumn + st.setString(32, "Sample$i") // xmlColumn st.setString(33, "SQL_VARIANT") // sqlvariantColumn - st.setBytes(34, - byteArrayOf(0xE6.toByte(), 0x10, 0x00, 0x00, 0x01, 0x0C, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x44, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x05, 0x4C, 0x0)) // geometryColumn + st.setBytes( + 34, byteArrayOf( + 0xE6.toByte(), 0x10, 0x00, 0x00, 0x01, 0x0C, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x44, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x05, 0x4C, 0x0 + ) + ) // geometryColumn st.setString(35, "POINT(1 1)") // geographyColumn - // 
st.executeUpdate() - } - } - - /* connection.prepareStatement(insertData2).use { st -> - // Insert data into table2 - for (i in 1..3) { - st.setBoolean(1, false) - st.setByte(2, (i * 2).toByte()) - st.setShort(3, (i * 20).toShort()) - st.setInt(4, i * 200) - st.setInt(5, i * 200) - st.setInt(6, i * 200) - st.setInt(7, i * 200) - st.setInt(8, i * 200) - st.setInt(9, i * 200) - st.setFloat(10, i * 20.0f) - st.setDouble(11, i * 20.0) - st.setBigDecimal(12, BigDecimal(i * 20)) - st.setDate(13, java.sql.Date(System.currentTimeMillis())) - st.setTimestamp(14, java.sql.Timestamp(System.currentTimeMillis())) - st.setTimestamp(15, java.sql.Timestamp(System.currentTimeMillis())) - st.setTime(16, java.sql.Time(System.currentTimeMillis())) - st.setInt(17, 2023) - st.setString(18, "varcharValue$i") - st.setString(19, "charValue$i") - st.setBytes(20, "binaryValue".toByteArray()) - st.setBytes(21, "varbinaryValue".toByteArray()) - st.setBytes(22, "tinyblobValue".toByteArray()) - st.setBytes(23, "blobValue".toByteArray()) - st.setBytes(24, "mediumblobValue".toByteArray()) - st.setBytes(25, "longblobValue".toByteArray()) - st.setString(26, null) - st.setString(27, null) - st.setString(28, "longtextValue$i") - st.setString(29, "Value$i") - st.setString(30, "Option$i") st.executeUpdate() } - }*/ + } } @AfterClass @JvmStatic fun tearDownClass() { - /*try { - connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table1") } - connection.createStatement().use { st -> st.execute("DROP TABLE IF EXISTS table2") } + try { connection.createStatement().use { st -> st.execute("DROP DATABASE IF EXISTS $TEST_DATABASE_NAME") } connection.close() } catch (e: SQLException) { e.printStackTrace() - }*/ + } } } @Test fun `basic test for reading sql tables`() { - val df1 = DataFrame.readSqlTable(connection, "table1").cast() + val df1 = DataFrame.readSqlTable(connection, "table1", limit = 5).cast() + val result = df1.filter { it[Table1MSSSQL::id] == 1 } - result[0][26] shouldBe "textValue1" + result[0][30] shouldBe "Sample1" + result[0][Table1MSSSQL::bigintColumn] shouldBe 123456789012345L + result[0][Table1MSSSQL::bitColumn] shouldBe true + result[0][Table1MSSSQL::intColumn] shouldBe 123456 + result[0][Table1MSSSQL::ntextColumn] shouldBe "Sample1 text" - /*val schema = DataFrame.getSchemaForSqlTable(connection, "table1") + val schema = DataFrame.getSchemaForSqlTable(connection, "table1") schema.columns["id"]!!.type shouldBe typeOf() - schema.columns["textCol"]!!.type shouldBe typeOf() - - val df2 = DataFrame.readSqlTable(connection, "table2").cast() - val result2 = df2.filter { it[Table2MSSQL::id] == 1 } - result2[0][26] shouldBe null - - val schema2 = DataFrame.getSchemaForSqlTable(connection, "table2") - schema2.columns["id"]!!.type shouldBe typeOf() - schema2.columns["textCol"]!!.type shouldBe typeOf()*/ + schema.columns["bigintColumn"]!!.type shouldBe typeOf() + schema.columns["binaryColumn"]!!.type shouldBe typeOf() + schema.columns["bitColumn"]!!.type shouldBe typeOf() + schema.columns["charColumn"]!!.type shouldBe typeOf() + schema.columns["dateColumn"]!!.type shouldBe typeOf() + schema.columns["datetime3Column"]!!.type shouldBe typeOf() + schema.columns["datetime2Column"]!!.type shouldBe typeOf() + schema.columns["datetimeoffset2Column"]!!.type shouldBe typeOf() + schema.columns["decimalColumn"]!!.type shouldBe typeOf() + schema.columns["floatColumn"]!!.type shouldBe typeOf() + schema.columns["imageColumn"]!!.type shouldBe typeOf() + schema.columns["intColumn"]!!.type shouldBe typeOf() + 
schema.columns["moneyColumn"]!!.type shouldBe typeOf() + schema.columns["ncharColumn"]!!.type shouldBe typeOf() + schema.columns["ntextColumn"]!!.type shouldBe typeOf() + schema.columns["numericColumn"]!!.type shouldBe typeOf() + schema.columns["nvarcharColumn"]!!.type shouldBe typeOf() + schema.columns["nvarcharMaxColumn"]!!.type shouldBe typeOf() + schema.columns["realColumn"]!!.type shouldBe typeOf() + schema.columns["smalldatetimeColumn"]!!.type shouldBe typeOf() + schema.columns["smallintColumn"]!!.type shouldBe typeOf() + schema.columns["smallmoneyColumn"]!!.type shouldBe typeOf() + schema.columns["timeColumn"]!!.type shouldBe typeOf() + schema.columns["timestampColumn"]!!.type shouldBe typeOf() + schema.columns["tinyintColumn"]!!.type shouldBe typeOf() + schema.columns["uniqueidentifierColumn"]!!.type shouldBe typeOf() + schema.columns["varbinaryColumn"]!!.type shouldBe typeOf() + schema.columns["varbinaryMaxColumn"]!!.type shouldBe typeOf() + schema.columns["varcharColumn"]!!.type shouldBe typeOf() + schema.columns["varcharMaxColumn"]!!.type shouldBe typeOf() + schema.columns["xmlColumn"]!!.type shouldBe typeOf() + schema.columns["sqlvariantColumn"]!!.type shouldBe typeOf() + schema.columns["geometryColumn"]!!.type shouldBe typeOf() + schema.columns["geographyColumn"]!!.type shouldBe typeOf() } @Test fun `read from sql query`() { - /* @Language("SQL") + @Language("SQL") val sqlQuery = """ - SELECT - t1.id, - t1.enumCol, - t2.setCol - FROM table1 t1 - JOIN table2 t2 ON t1.id = t2.id - """.trimIndent() + SELECT + Table1.id, + Table1.bigintColumn + FROM Table1 + """.trimIndent() - val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery).cast() - val result = df.filter { it[Table3MSSQL::id] == 1 } - result[0][2] shouldBe "Option1" + val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery, limit = 3).cast() + val result = df.filter { it[Table1MSSSQL::id] == 1 } + result[0][Table1MSSSQL::bigintColumn] shouldBe 123456789012345L val schema = DataFrame.getSchemaForSqlQuery(connection, sqlQuery = sqlQuery) schema.columns["id"]!!.type shouldBe typeOf() - schema.columns["enumCol"]!!.type shouldBe typeOf() - schema.columns["setCol"]!!.type shouldBe typeOf()*/ + schema.columns["bigintColumn"]!!.type shouldBe typeOf() } @Test fun `read from all tables`() { - /* val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 1000) - - val table1Df = dataframes[0].cast() - - table1Df.rowsCount() shouldBe 3 - table1Df.filter { it[Table1MSSSQL::integerCol] > 100 }.rowsCount() shouldBe 2 - table1Df[0][11] shouldBe 10.0 - table1Df[0][26] shouldBe "textValue1" - table1Df[0][31] shouldBe JSON_STRING // TODO: https://github.com/Kotlin/dataframe/issues/462 - - val table2Df = dataframes[1].cast() - - table2Df.rowsCount() shouldBe 3 - table2Df.filter { it[Table2MSSQL::integerCol] != null && it[Table2MSSQL::integerCol]!! 
> 400 } - .rowsCount() shouldBe 1 - table2Df[0][11] shouldBe 20.0 - table2Df[0][26] shouldBe null*/ - } - - @Test - fun `reading numeric types`() { - /* val df1 = DataFrame.readSqlTable(connection, "table1").cast() - - val result = df1.select("tinyintCol") - .add("tinyintCol2") { it[Table1MSSSQL::tinyintCol] } - - result[0][1] shouldBe 1 - - val result1 = df1.select("smallintCol") - .add("smallintCol2") { it[Table1MSSSQL::smallintCol] } - - result1[0][1] shouldBe 10 - - val result2 = df1.select("mediumintCol") - .add("mediumintCol2") { it[Table1MSSSQL::mediumintCol] } - - result2[0][1] shouldBe 100 - - val result3 = df1.select("mediumintUnsignedCol") - .add("mediumintUnsignedCol2") { it[Table1MSSSQL::mediumintUnsignedCol] } - - result3[0][1] shouldBe 100 + val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 4) - val result4 = df1.select("integerUnsignedCol") - .add("integerUnsignedCol2") { it[Table1MSSSQL::integerUnsignedCol] } - - result4[0][1] shouldBe 100L - - val result5 = df1.select("bigintCol") - .add("bigintCol2") { it[Table1MSSSQL::bigintCol] } - - result5[0][1] shouldBe 100 - - val result6 = df1.select("floatCol") - .add("floatCol2") { it[Table1MSSSQL::floatCol] } - - result6[0][1] shouldBe 10.0f - - val result7 = df1.select("doubleCol") - .add("doubleCol2") { it[Table1MSSSQL::doubleCol] } - - result7[0][1] shouldBe 10.0 - - val result8 = df1.select("decimalCol") - .add("decimalCol2") { it[Table1MSSSQL::decimalCol] } - - result8[0][1] shouldBe BigDecimal("10") - - val schema = DataFrame.getSchemaForSqlTable(connection, "table1") + val table1Df = dataframes[0].cast() - schema.columns["tinyintCol"]!!.type shouldBe typeOf() - schema.columns["smallintCol"]!!.type shouldBe typeOf() - schema.columns["mediumintCol"]!!.type shouldBe typeOf() - schema.columns["mediumintUnsignedCol"]!!.type shouldBe typeOf() - schema.columns["integerUnsignedCol"]!!.type shouldBe typeOf() - schema.columns["bigintCol"]!!.type shouldBe typeOf() - schema.columns["floatCol"]!!.type shouldBe typeOf() - schema.columns["doubleCol"]!!.type shouldBe typeOf() - schema.columns["decimalCol"]!!.type shouldBe typeOf()*/ + table1Df.rowsCount() shouldBe 3 + table1Df.filter { it[Table1MSSSQL::intColumn] > 2 }.rowsCount() shouldBe 3 + table1Df[0][Table1MSSSQL::bigintColumn] shouldBe 123456789012345L } } From 9cee3355280310111a09f5c7bb06e0711c930c1a Mon Sep 17 00:00:00 2001 From: zaleslaw Date: Tue, 30 Apr 2024 17:20:17 +0200 Subject: [PATCH 03/10] Refactor SQL query limit implementation across databases The SQL query limit behavior has been updated to use a unified method, `sqlQueryLimit`, on different database types, instead of hardcoding this limitation. This ensures a consistent application of these limits across different databases. Also added `TODO` comments to address the nullability issues and points to be checked like filtering system tables and special behavior with catalogues in MSSQL in the future. 
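A rough sketch of the unified behaviour as of this commit (the query text is only an example):

    // default implementation in DbType
    PostgreSql.sqlQueryLimit("SELECT * FROM Customer", limit = 10)
    // -> "SELECT * FROM (SELECT * FROM Customer) as LIMIT_TABLE LIMIT 10"

    // MSSQL override
    MsSql.sqlQueryLimit("SELECT * FROM Customer", limit = 10)
    // -> "SELECT TOP 10 * FROM (SELECT * FROM Customer) as LIMIT_TABLE"
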
--- .../org/jetbrains/kotlinx/dataframe/io/db/DbType.kt | 3 ++- .../org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt | 5 ++++- .../org/jetbrains/kotlinx/dataframe/io/readJdbc.kt | 13 +++++++------ .../org/jetbrains/kotlinx/dataframe/io/h2Test.kt | 2 ++ .../org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt | 1 + 5 files changed, 16 insertions(+), 8 deletions(-) diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt index ee86d5514..32abb712f 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt @@ -51,5 +51,6 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) { */ public abstract fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? - public open fun sqlQueryLimitOne(tableName: String): String = "SELECT * FROM $tableName LIMIT 1" + public open fun sqlQueryLimit(sqlQuery: String, limit: Int = 1): String = + "SELECT * FROM ($sqlQuery) as LIMIT_TABLE LIMIT $limit" } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt index 2ba9c847e..f62b0891b 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt @@ -21,10 +21,12 @@ public object MsSql : DbType("sqlserver") { return null } + // TODO: need to find solution to filter system tables override fun isSystemTable(tableMetadata: TableMetadata): Boolean { return MySql.isSystemTable(tableMetadata) } + // TODO: need to check override fun buildTableMetadata(tables: ResultSet): TableMetadata { return TableMetadata( tables.getString("table_name"), @@ -37,5 +39,6 @@ public object MsSql : DbType("sqlserver") { return null } - public override fun sqlQueryLimitOne(tableName: String): String = "SELECT TOP 1 * FROM $tableName" + public override fun sqlQueryLimit(sqlQuery: String, limit: Int): String = + "SELECT TOP $limit * FROM ($sqlQuery) as LIMIT_TABLE" } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt index a36b43e14..339b035d7 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt @@ -137,13 +137,13 @@ public fun DataFrame.Companion.readSqlTable( limit: Int = DEFAULT_LIMIT, inferNullability: Boolean = true, ): AnyFrame { - var preparedQuery = "SELECT * FROM $tableName" - if (limit > 0) preparedQuery += " LIMIT $limit" - val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) - connection.createStatement().use { st -> + var preparedQuery = "SELECT * FROM $tableName" + if (limit > 0) preparedQuery = dbType.sqlQueryLimit(preparedQuery, limit) + + connection.createStatement().use { st -> logger.debug { "Connection with url:$url is established successfully." 
} st.executeQuery( @@ -207,7 +207,7 @@ public fun DataFrame.Companion.readSqlQuery( val dbType = extractDBTypeFromUrl(url) var internalSqlQuery = sqlQuery - if (limit > 0) internalSqlQuery += " LIMIT $limit" + if (limit > 0) internalSqlQuery = dbType.sqlQueryLimit(internalSqlQuery, limit) logger.debug { "Executing SQL query: $internalSqlQuery" } @@ -367,7 +367,8 @@ public fun DataFrame.Companion.getSchemaForSqlTable( val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) - val selectFirstRowQuery = dbType.sqlQueryLimitOne(tableName) + val sqlQuery = "SELECT * FROM $tableName" + val selectFirstRowQuery = dbType.sqlQueryLimit(sqlQuery, limit = 1) connection.createStatement().use { st -> st.executeQuery( diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt index 864a0c4da..2f1f1f9ae 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt @@ -633,6 +633,7 @@ class JdbcTest { val saleDataSchema = dataSchemas[1] saleDataSchema.columns.size shouldBe 3 + // TODO: fix nullability saleDataSchema.columns["amount"]!!.type shouldBe typeOf() val dbConfig = DatabaseConfiguration(url = URL) @@ -707,6 +708,7 @@ class JdbcTest { val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName) dataSchema.columns.size shouldBe 4 + // TODO: fix nullability dataSchema.columns["id"]!!.type shouldBe typeOf() dataSchema.columns["name"]!!.type shouldBe typeOf() dataSchema.columns["surname"]!!.type shouldBe typeOf() diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt index 3370070ca..fe2d21efa 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt @@ -270,6 +270,7 @@ class MSSQLTest { schema.columns["bigintColumn"]!!.type shouldBe typeOf() } + // TODO: special behaviour with catalogues in MSSQL? @Test fun `read from all tables`() { val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 4) From 459fb3ad9f5e17f3370dfc7b91f06cf9c49caf8c Mon Sep 17 00:00:00 2001 From: zaleslaw Date: Mon, 6 May 2024 13:44:20 +0200 Subject: [PATCH 04/10] Refactored and improved SQL query limit and nullability handling --- .../kotlinx/dataframe/io/db/DbType.kt | 10 +- .../kotlinx/dataframe/io/db/MsSql.kt | 6 +- .../jetbrains/kotlinx/dataframe/io/db/util.kt | 2 +- .../kotlinx/dataframe/io/readJdbc.kt | 27 ++-- .../jetbrains/kotlinx/dataframe/io/h2Test.kt | 3 +- .../kotlinx/dataframe/io/mssqlTest.kt | 120 +++++++++++++++++- 6 files changed, 152 insertions(+), 16 deletions(-) diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt index 32abb712f..b4998269b 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt @@ -51,6 +51,14 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) { */ public abstract fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? + + /** + * Constructs a SQL query with a limit clause. + * + * @param sqlQuery The original SQL query. 
+ * @param limit The maximum number of rows to retrieve from the query. Default is 1. + * @return A new SQL query with the limit clause added. + */ public open fun sqlQueryLimit(sqlQuery: String, limit: Int = 1): String = - "SELECT * FROM ($sqlQuery) as LIMIT_TABLE LIMIT $limit" + "$sqlQuery LIMIT $limit" } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt index f62b0891b..52b582aa2 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt @@ -39,6 +39,8 @@ public object MsSql : DbType("sqlserver") { return null } - public override fun sqlQueryLimit(sqlQuery: String, limit: Int): String = - "SELECT TOP $limit * FROM ($sqlQuery) as LIMIT_TABLE" + public override fun sqlQueryLimit(sqlQuery: String, limit: Int): String { + sqlQuery.replace("SELECT", "SELECT TOP $limit", ignoreCase = true) + return sqlQuery + } } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt index 1c3e9c238..793b41a93 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt @@ -20,7 +20,7 @@ public fun extractDBTypeFromUrl(url: String?): DbType { MsSql.dbTypeInJdbcUrl in url -> MsSql else -> throw IllegalArgumentException( "Unsupported database type in the url: $url. " + - "Only H2, MariaDB, MySQL, SQLite and PostgreSQL are supported!" + "Only H2, MariaDB, MySQL, MSSQL, SQLite and PostgreSQL are supported!" ) } } else { diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt index 339b035d7..b3a01290c 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt @@ -140,14 +140,14 @@ public fun DataFrame.Companion.readSqlTable( val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) - var preparedQuery = "SELECT * FROM $tableName" - if (limit > 0) preparedQuery = dbType.sqlQueryLimit(preparedQuery, limit) + var selectAllQuery = "SELECT * FROM $tableName" + if (limit > 0) selectAllQuery = dbType.sqlQueryLimit(selectAllQuery, limit) connection.createStatement().use { st -> logger.debug { "Connection with url:$url is established successfully." 
} st.executeQuery( - preparedQuery + selectAllQuery ).use { rs -> val tableColumns = getTableColumnsMetadata(rs) return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit, inferNullability) @@ -207,7 +207,9 @@ public fun DataFrame.Companion.readSqlQuery( val dbType = extractDBTypeFromUrl(url) var internalSqlQuery = sqlQuery - if (limit > 0) internalSqlQuery = dbType.sqlQueryLimit(internalSqlQuery, limit) + if (limit > 0) { + internalSqlQuery = dbType.sqlQueryLimit(internalSqlQuery, limit) + } logger.debug { "Executing SQL query: $internalSqlQuery" } @@ -317,9 +319,11 @@ public fun DataFrame.Companion.readAllSqlTables( val table = dbType.buildTableMetadata(tables) if (!dbType.isSystemTable(table)) { // we filter her second time because of specific logic with SQLite and possible issues with future databases - // val tableName = if (table.catalogue != null) table.catalogue + "." + table.name else table.name - val tableName = if (catalogue != null) catalogue + "." + table.name else table.name - + val tableName = when { + catalogue != null && table.schemaName != null -> "$catalogue.${table.schemaName}.${table.name}" + catalogue != null && table.schemaName == null -> "$catalogue.${table.name}" + else -> table.name + } // TODO: both cases is schema specified or not in URL // in h2 database name is recognized as a schema name https://www.h2database.com/html/features.html#database_url // https://stackoverflow.com/questions/20896935/spring-hibernate-h2-database-schema-not-found @@ -532,16 +536,21 @@ private fun getTableColumnsMetadata(rs: ResultSet): MutableList() } + // TODO: add the same test for each particular database and refactor the scenario to the common test case + // https://github.com/Kotlin/dataframe/issues/688 @Test fun `infer nullability`() { // prepare tables and data @@ -708,7 +710,6 @@ class JdbcTest { val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName) dataSchema.columns.size shouldBe 4 - // TODO: fix nullability dataSchema.columns["id"]!!.type shouldBe typeOf() dataSchema.columns["name"]!!.type shouldBe typeOf() dataSchema.columns["surname"]!!.type shouldBe typeOf() diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt index fe2d21efa..3d9c99a31 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt @@ -5,6 +5,8 @@ import org.intellij.lang.annotations.Language import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.annotations.DataSchema import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.JdbcTest.Companion +import org.jetbrains.kotlinx.dataframe.io.db.H2 import org.junit.AfterClass import org.junit.BeforeClass import org.junit.Ignore @@ -12,6 +14,7 @@ import org.junit.Test import java.math.BigDecimal import java.sql.Connection import java.sql.DriverManager +import java.sql.ResultSet import java.sql.SQLException import java.util.* import kotlin.reflect.typeOf @@ -277,8 +280,121 @@ class MSSQLTest { val table1Df = dataframes[0].cast() - table1Df.rowsCount() shouldBe 3 - table1Df.filter { it[Table1MSSSQL::intColumn] > 2 }.rowsCount() shouldBe 3 + table1Df.rowsCount() shouldBe 4 + table1Df.filter { it[Table1MSSSQL::id] > 2 }.rowsCount() shouldBe 2 table1Df[0][Table1MSSSQL::bigintColumn] shouldBe 123456789012345L } + + // TODO: add the same test for 
each particular database and refactor the scenario to the common test case + // https://github.com/Kotlin/dataframe/issues/688 + @Test + fun `infer nullability`() { + // prepare tables and data + @Language("SQL") + val createTestTable1Query = """ + CREATE TABLE TestTable1 ( + id INT PRIMARY KEY, + name VARCHAR(50), + surname VARCHAR(50), + age INT NOT NULL + ) + """ + + connection.createStatement().execute(createTestTable1Query) + + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (1, 'John', 'Crawford', 40)") + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (2, 'Alice', 'Smith', 25)") + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (3, 'Bob', 'Johnson', 47)") + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (4, 'Sam', NULL, 15)") + + // start testing `readSqlTable` method + + // with default inferNullability: Boolean = true + val tableName = "TestTable1" + val df = DataFrame.readSqlTable(connection, tableName) + df.schema().columns["id"]!!.type shouldBe typeOf() + df.schema().columns["name"]!!.type shouldBe typeOf() + df.schema().columns["surname"]!!.type shouldBe typeOf() + df.schema().columns["age"]!!.type shouldBe typeOf() + + val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName) + dataSchema.columns.size shouldBe 4 + dataSchema.columns["id"]!!.type shouldBe typeOf() + dataSchema.columns["name"]!!.type shouldBe typeOf() + dataSchema.columns["surname"]!!.type shouldBe typeOf() + dataSchema.columns["age"]!!.type shouldBe typeOf() + + // with inferNullability: Boolean = false + val df1 = DataFrame.readSqlTable(connection, tableName, inferNullability = false) + df1.schema().columns["id"]!!.type shouldBe typeOf() + df1.schema().columns["name"]!!.type shouldBe typeOf() // <=== this column changed a type because it doesn't contain nulls + df1.schema().columns["surname"]!!.type shouldBe typeOf() + df1.schema().columns["age"]!!.type shouldBe typeOf() + + // end testing `readSqlTable` method + + // start testing `readSQLQuery` method + + // ith default inferNullability: Boolean = true + @Language("SQL") + val sqlQuery = """ + SELECT name, surname, age FROM TestTable1 + """.trimIndent() + + val df2 = DataFrame.readSqlQuery(connection, sqlQuery) + df2.schema().columns["name"]!!.type shouldBe typeOf() + df2.schema().columns["surname"]!!.type shouldBe typeOf() + df2.schema().columns["age"]!!.type shouldBe typeOf() + + val dataSchema2 = DataFrame.getSchemaForSqlQuery(connection, sqlQuery) + dataSchema2.columns.size shouldBe 3 + dataSchema2.columns["name"]!!.type shouldBe typeOf() + dataSchema2.columns["surname"]!!.type shouldBe typeOf() + dataSchema2.columns["age"]!!.type shouldBe typeOf() + + // with inferNullability: Boolean = false + val df3 = DataFrame.readSqlQuery(connection, sqlQuery, inferNullability = false) + df3.schema().columns["name"]!!.type shouldBe typeOf() // <=== this column changed a type because it doesn't contain nulls + df3.schema().columns["surname"]!!.type shouldBe typeOf() + df3.schema().columns["age"]!!.type shouldBe typeOf() + + // end testing `readSQLQuery` method + + // start testing `readResultSet` method + + connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st -> + @Language("SQL") + val selectStatement = "SELECT * FROM TestTable1" + + st.executeQuery(selectStatement).use { rs -> + // ith default inferNullability: Boolean = true + 
val df4 = DataFrame.readResultSet(rs, H2) + df4.schema().columns["id"]!!.type shouldBe typeOf<Int>() + df4.schema().columns["name"]!!.type shouldBe typeOf<String>() + df4.schema().columns["surname"]!!.type shouldBe typeOf<String?>() + df4.schema().columns["age"]!!.type shouldBe typeOf<Int>() + + rs.beforeFirst() + + val dataSchema3 = DataFrame.getSchemaForResultSet(rs, H2) + dataSchema3.columns.size shouldBe 4 + dataSchema3.columns["id"]!!.type shouldBe typeOf<Int>() + dataSchema3.columns["name"]!!.type shouldBe typeOf<String?>() + dataSchema3.columns["surname"]!!.type shouldBe typeOf<String?>() + dataSchema3.columns["age"]!!.type shouldBe typeOf<Int>() + + // with inferNullability: Boolean = false + rs.beforeFirst() + + val df5 = DataFrame.readResultSet(rs, H2, inferNullability = false) + df5.schema().columns["id"]!!.type shouldBe typeOf<Int>() + df5.schema().columns["name"]!!.type shouldBe typeOf<String?>() // <=== this column changed a type because it doesn't contain nulls + df5.schema().columns["surname"]!!.type shouldBe typeOf<String?>() + df5.schema().columns["age"]!!.type shouldBe typeOf<Int>() + } + } + // end testing `readResultSet` method + + connection.createStatement().execute("DROP TABLE TestTable1") + } } From c828ed9ca1112e8198eb96a25575aaa52df5fced Mon Sep 17 00:00:00 2001 From: zaleslaw Date: Mon, 6 May 2024 13:45:05 +0200 Subject: [PATCH 05/10] Ignore MSSQLTest class in test execution --- .../test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt | 1 + 1 file changed, 1 insertion(+) diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt index 3d9c99a31..ee3d0a87c 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt @@ -63,6 +63,7 @@ interface Table1MSSSQL { val geographyColumn: String } +@Ignore class MSSQLTest { companion object { private lateinit var connection: Connection From 1cd608d87e5bb9ac88f003998a9a925047cba89f Mon Sep 17 00:00:00 2001 From: zaleslaw Date: Mon, 6 May 2024 14:00:09 +0200 Subject: [PATCH 06/10] Add MSSQL support and clean up code This commit adds the Microsoft SQL Server (MSSQL) JDBC driver to the dataframe-jdbc project's test dependencies. Also, system table filtering has been implemented specifically for MSSQL by adjusting the isSystemTable method. This is a significant improvement over the previous assumption that all DBMSs are similar to MySQL.
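The idea behind the new filter: MSSQL keeps its system objects in the sys and INFORMATION_SCHEMA schemas, uses sys/dt name prefixes, and ships the master, model, msdb and tempdb system databases. A simplified, standalone sketch of that check is shown below; the actual override lives in MsSql.kt in the diff that follows, so this snippet only illustrates the rule and is not the committed code:

    // Simplified sketch of the MSSQL system-table heuristic used in this patch.
    // The real implementation in MsSql.kt additionally matches the "sys_config" table name.
    fun isSystemTable(schemaName: String?, tableName: String, catalogName: String?): Boolean {
        fun String?.containsIgnoreCase(s: String) = this?.contains(s, ignoreCase = true) == true
        val systemCatalogs = listOf("system", "master", "model", "msdb", "tempdb")
        return schemaName.containsIgnoreCase("sys") ||
            schemaName.containsIgnoreCase("information_schema") ||
            tableName.startsWith("sys") ||
            tableName.startsWith("dt") ||
            systemCatalogs.any { catalogName.containsIgnoreCase(it) }
    }
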
--- dataframe-jdbc/build.gradle.kts | 3 +-- .../kotlinx/dataframe/io/db/DbType.kt | 1 - .../kotlinx/dataframe/io/db/MsSql.kt | 22 ++++++++++++++++--- .../kotlinx/dataframe/io/mssqlTest.kt | 1 - gradle/libs.versions.toml | 2 ++ 5 files changed, 22 insertions(+), 7 deletions(-) diff --git a/dataframe-jdbc/build.gradle.kts b/dataframe-jdbc/build.gradle.kts index aa7d4302a..ed29c5830 100644 --- a/dataframe-jdbc/build.gradle.kts +++ b/dataframe-jdbc/build.gradle.kts @@ -25,8 +25,7 @@ dependencies { testImplementation(libs.postgresql) testImplementation(libs.mysql) testImplementation(libs.h2db) - // TODO - testImplementation ("com.microsoft.sqlserver:mssql-jdbc:12.6.1.jre11") + testImplementation (libs.mssql) testImplementation(libs.junit) testImplementation(libs.sl4j) testImplementation(libs.kotestAssertions) { diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt index b4998269b..aae6eb995 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt @@ -51,7 +51,6 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) { */ public abstract fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? - /** * Constructs a SQL query with a limit clause. * diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt index 52b582aa2..4dd180977 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt @@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata import org.jetbrains.kotlinx.dataframe.io.TableMetadata import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema import java.sql.ResultSet +import java.util.* import kotlin.reflect.KType import kotlin.reflect.full.createType @@ -21,12 +22,27 @@ public object MsSql : DbType("sqlserver") { return null } - // TODO: need to find solution to filter system tables override fun isSystemTable(tableMetadata: TableMetadata): Boolean { - return MySql.isSystemTable(tableMetadata) + val locale = Locale.getDefault() + + fun String?.containsWithLowercase(substr: String) = this?.lowercase(locale)?.contains(substr) == true + + val schemaName = tableMetadata.schemaName + val tableName = tableMetadata.name + val catalogName = tableMetadata.catalogue + + return schemaName.containsWithLowercase("sys") || + schemaName.containsWithLowercase("information_schema") || + tableName.startsWith("sys") || + tableName.startsWith("dt") || + tableName.containsWithLowercase("sys_config") || + catalogName.containsWithLowercase("system") || + catalogName.containsWithLowercase("master") || + catalogName.containsWithLowercase("model") || + catalogName.containsWithLowercase("msdb") || + catalogName.containsWithLowercase("tempdb") } - // TODO: need to check override fun buildTableMetadata(tables: ResultSet): TableMetadata { return TableMetadata( tables.getString("table_name"), diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt index ee3d0a87c..3d9c99a31 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt +++ 
b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt @@ -63,7 +63,6 @@ interface Table1MSSSQL { val geographyColumn: String } -@Ignore class MSSQLTest { companion object { private lateinit var connection: Connection diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 402002d81..38e1e045e 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -31,6 +31,7 @@ fuel = "2.3.1" poi = "5.2.5" mariadb = "3.3.2" h2db = "2.2.224" +mssql = "12.6.1.jre11" mysql = "8.3.0" postgresql = "42.7.2" sqlite = "3.45.1.0" @@ -75,6 +76,7 @@ fuel = { group = "com.github.kittinunf.fuel", name = "fuel", version.ref = "fuel poi = { group = "org.apache.poi", name = "poi", version.ref = "poi" } mariadb = { group = "org.mariadb.jdbc", name = "mariadb-java-client", version.ref = "mariadb" } h2db = { group = "com.h2database", name = "h2", version.ref = "h2db" } +mssql = { group = "com.microsoft.sqlserver", name = "mssql-jdbc", version.ref = "mssql" } mysql = { group = "com.mysql", name = "mysql-connector-j", version.ref = "mysql" } postgresql = { group = "org.postgresql", name = "postgresql", version.ref = "postgresql" } sqlite = { group = "org.xerial", name = "sqlite-jdbc", version.ref = "sqlite" } From b1a0959da0a1c94e667fbcfe32d7bb57b855a903 Mon Sep 17 00:00:00 2001 From: zaleslaw Date: Mon, 6 May 2024 14:01:51 +0200 Subject: [PATCH 07/10] Ignore MSSQLTest class in unit tests --- .../test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt | 1 + 1 file changed, 1 insertion(+) diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt index 3d9c99a31..ee3d0a87c 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/mssqlTest.kt @@ -63,6 +63,7 @@ interface Table1MSSSQL { val geographyColumn: String } +@Ignore class MSSQLTest { companion object { private lateinit var connection: Connection From 95e617a588f612864406ae0f620d10ee697fec9a Mon Sep 17 00:00:00 2001 From: zaleslaw Date: Mon, 6 May 2024 14:07:20 +0200 Subject: [PATCH 08/10] Refactor indentation in Kotlin files --- .../kotlinx/dataframe/io/readJdbc.kt | 3 +-- .../kotlinx/dataframe/io/mssqlTest.kt | 26 +++++++++---------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt index b3a01290c..f4bc10980 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt @@ -143,7 +143,7 @@ public fun DataFrame.Companion.readSqlTable( var selectAllQuery = "SELECT * FROM $tableName" if (limit > 0) selectAllQuery = dbType.sqlQueryLimit(selectAllQuery, limit) - connection.createStatement().use { st -> + connection.createStatement().use { st -> logger.debug { "Connection with url:$url is established successfully." 
} st.executeQuery( @@ -536,7 +536,6 @@ private fun getTableColumnsMetadata(rs: ResultSet): MutableList for (i in 1..5) { @@ -183,7 +183,8 @@ class MSSQLTest { st.setString(32, "Sample$i") // xmlColumn st.setString(33, "SQL_VARIANT") // sqlvariantColumn st.setBytes( - 34, byteArrayOf( + 34, + byteArrayOf( 0xE6.toByte(), 0x10, 0x00, 0x00, 0x01, 0x0C, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x44, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x09, 0x05, 0x4C, 0x0 ) @@ -259,11 +260,11 @@ class MSSQLTest { fun `read from sql query`() { @Language("SQL") val sqlQuery = """ - SELECT - Table1.id, - Table1.bigintColumn - FROM Table1 - """.trimIndent() + SELECT + Table1.id, + Table1.bigintColumn + FROM Table1 + """.trimIndent() val df = DataFrame.readSqlQuery(connection, sqlQuery = sqlQuery, limit = 3).cast() val result = df.filter { it[Table1MSSSQL::id] == 1 } @@ -274,16 +275,15 @@ class MSSQLTest { schema.columns["bigintColumn"]!!.type shouldBe typeOf() } - // TODO: special behaviour with catalogues in MSSQL? @Test fun `read from all tables`() { - val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 4) + val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 4) - val table1Df = dataframes[0].cast() + val table1Df = dataframes[0].cast() - table1Df.rowsCount() shouldBe 4 - table1Df.filter { it[Table1MSSSQL::id] > 2 }.rowsCount() shouldBe 2 - table1Df[0][Table1MSSSQL::bigintColumn] shouldBe 123456789012345L + table1Df.rowsCount() shouldBe 4 + table1Df.filter { it[Table1MSSSQL::id] > 2 }.rowsCount() shouldBe 2 + table1Df[0][Table1MSSSQL::bigintColumn] shouldBe 123456789012345L } // TODO: add the same test for each particular database and refactor the scenario to the common test case From 1ebf6515cd793dcdd7937d3928ec85b27c4315f3 Mon Sep 17 00:00:00 2001 From: zaleslaw Date: Mon, 6 May 2024 14:10:34 +0200 Subject: [PATCH 09/10] Update comments and fix formatting in MsSql.kt and build.gradle.kts --- dataframe-jdbc/build.gradle.kts | 2 +- .../kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataframe-jdbc/build.gradle.kts b/dataframe-jdbc/build.gradle.kts index ed29c5830..6e7e8dd42 100644 --- a/dataframe-jdbc/build.gradle.kts +++ b/dataframe-jdbc/build.gradle.kts @@ -25,7 +25,7 @@ dependencies { testImplementation(libs.postgresql) testImplementation(libs.mysql) testImplementation(libs.h2db) - testImplementation (libs.mssql) + testImplementation(libs.mssql) testImplementation(libs.junit) testImplementation(libs.sl4j) testImplementation(libs.kotestAssertions) { diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt index 4dd180977..05aed59a7 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/MsSql.kt @@ -9,9 +9,9 @@ import kotlin.reflect.KType import kotlin.reflect.full.createType /** - * Represents the MariaDb database type. + * Represents the MSSQL database type. * - * This class provides methods to convert data from a ResultSet to the appropriate type for MariaDb, + * This class provides methods to convert data from a ResultSet to the appropriate type for MSSQL, * and to generate the corresponding column schema. 
*/ public object MsSql : DbType("sqlserver") { From 698b34ab7831d41df5cfaefe9a2ca3aac0f86b50 Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Mon, 13 May 2024 17:38:31 +0200 Subject: [PATCH 10/10] Refactor code to simplify SQL query construction --- .../org/jetbrains/kotlinx/dataframe/io/readJdbc.kt | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt index f4bc10980..527808fd7 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt @@ -140,8 +140,8 @@ public fun DataFrame.Companion.readSqlTable( val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) - var selectAllQuery = "SELECT * FROM $tableName" - if (limit > 0) selectAllQuery = dbType.sqlQueryLimit(selectAllQuery, limit) + val selectAllQuery = if (limit > 0) dbType.sqlQueryLimit("SELECT * FROM $tableName", limit) + else "SELECT * FROM $tableName" connection.createStatement().use { st -> logger.debug { "Connection with url:$url is established successfully." } @@ -206,10 +206,7 @@ public fun DataFrame.Companion.readSqlQuery( val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) - var internalSqlQuery = sqlQuery - if (limit > 0) { - internalSqlQuery = dbType.sqlQueryLimit(internalSqlQuery, limit) - } + val internalSqlQuery = if (limit > 0) dbType.sqlQueryLimit(sqlQuery, limit) else sqlQuery logger.debug { "Executing SQL query: $internalSqlQuery" }
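
With this refactoring both readSqlTable and readSqlQuery build the final SQL through a single expression: when limit > 0 the query is rewritten by the database-specific dbType.sqlQueryLimit(...), otherwise the original query is executed unchanged. A minimal usage sketch, reusing the connection and the Table1MSSSQL schema from mssqlTest.kt above (illustrative only, not part of the patch):

    // Reads at most three rows of Table1; the limit is applied through the
    // DbType-specific sqlQueryLimit rewrite shown in the diff above.
    val df = DataFrame.readSqlQuery(
        connection,
        sqlQuery = "SELECT Table1.id, Table1.bigintColumn FROM Table1",
        limit = 3,
    ).cast<Table1MSSSQL>()
    val firstRow = df.filter { it[Table1MSSSQL::id] == 1 }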