Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add extension functions for the ResultSet #772

Merged
merged 7 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public data class TableMetadata(val name: String, val schemaName: String?, val c
* @property [user] the username used for authentication (optional, default is empty string).
* @property [password] the password used for authentication (optional, default is empty string).
*/
public data class DatabaseConfiguration(
    val url: String,
    val user: String = "",
    val password: String = "",
)

/**
 * Connection settings for a database.
 *
 * @property [url] the database URL.
 * @property [user] the username used for authentication (optional, default is empty string).
 * @property [password] the password used for authentication (optional, default is empty string).
 */
public data class DbConnectionConfig(
    val url: String,
    val user: String = "",
    val password: String = "",
)

/**
* Reads data from an SQL table and converts it into a DataFrame.
Expand All @@ -110,7 +110,7 @@ public data class DatabaseConfiguration(val url: String, val user: String = "",
* @return the DataFrame containing the data from the SQL table.
*/
public fun DataFrame.Companion.readSqlTable(
dbConfig: DatabaseConfiguration,
dbConfig: DbConnectionConfig,
tableName: String,
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
Expand Down Expand Up @@ -169,7 +169,7 @@ public fun DataFrame.Companion.readSqlTable(
* @return the DataFrame containing the result of the SQL query.
*/
public fun DataFrame.Companion.readSqlQuery(
dbConfig: DatabaseConfiguration,
dbConfig: DbConnectionConfig,
sqlQuery: String,
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
Expand Down Expand Up @@ -218,6 +218,89 @@ public fun DataFrame.Companion.readSqlQuery(
}
}

/**
 * Loads a DataFrame from either an SQL query or an SQL table, using this configuration to reach the database.
 *
 * The argument is first tested as a query and, only if that fails, as a table name;
 * the work is then delegated to [DataFrame.readSqlQuery] or [DataFrame.readSqlTable] respectively.
 *
 * @param [sqlQueryOrTableName] the SQL query to execute or the name of an existing SQL table.
 * A query should start from SELECT, contain a single statement that only reads data,
 * and should not contain the `;` symbol.
 * @param [limit] the maximum number of rows to retrieve.
 * @param [inferNullability] indicates how the column nullability should be inferred.
 * @return the DataFrame containing the data of the query result or of the table.
 * @throws IllegalArgumentException if [sqlQueryOrTableName] is recognized neither as a query nor as a table name.
 */
public fun DbConnectionConfig.readDataFrame(
    sqlQueryOrTableName: String,
    limit: Int = DEFAULT_LIMIT,
    inferNullability: Boolean = true,
): AnyFrame {
    // The query check goes first, so anything containing the SELECT keyword is never treated as a table name.
    if (isSqlQuery(sqlQueryOrTableName)) {
        return DataFrame.readSqlQuery(this, sqlQueryOrTableName, limit, inferNullability)
    }
    if (isSqlTableName(sqlQueryOrTableName)) {
        return DataFrame.readSqlTable(this, sqlQueryOrTableName, limit, inferNullability)
    }
    throw IllegalArgumentException(
        "$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!",
    )
}

/**
 * Tells whether [sqlQueryOrTableName] looks like an SQL query.
 *
 * The text counts as a query when the standalone word SELECT occurs anywhere in it, in any letter case.
 */
private fun isSqlQuery(sqlQueryOrTableName: String): Boolean {
    val selectKeyword = Regex("(?i)\\b(SELECT)\\b")
    val candidate = sqlQueryOrTableName.trim()
    return candidate.contains(selectKeyword)
}

/**
 * Tells whether [sqlQueryOrTableName] looks like a table name.
 *
 * After trimming surrounding whitespace, the whole text must be one identifier
 * (a letter or underscore followed by letters, digits or underscores),
 * optionally followed by up to two more dot-separated identifiers (e.g., catalog.schema.table).
 */
private fun isSqlTableName(sqlQueryOrTableName: String): Boolean {
    val qualifiedIdentifier = Regex("^[a-zA-Z_][a-zA-Z0-9_]*(\\.[a-zA-Z_][a-zA-Z0-9_]*){0,2}$")
    val candidate = sqlQueryOrTableName.trim()
    return candidate.matches(qualifiedIdentifier)
}

/**
 * Loads a DataFrame from either an SQL query or an SQL table through this database connection.
 *
 * The argument is first tested as a query and, only if that fails, as a table name;
 * the work is then delegated to [DataFrame.readSqlQuery] or [DataFrame.readSqlTable] respectively.
 *
 * @param [sqlQueryOrTableName] the SQL query to execute or the name of an existing SQL table.
 * A query should start from SELECT, contain a single statement that only reads data,
 * and should not contain the `;` symbol.
 * @param [limit] the maximum number of rows to retrieve.
 * @param [inferNullability] indicates how the column nullability should be inferred.
 * @return the DataFrame containing the data of the query result or of the table.
 * @throws IllegalArgumentException if [sqlQueryOrTableName] is recognized neither as a query nor as a table name.
 */
public fun Connection.readDataFrame(
    sqlQueryOrTableName: String,
    limit: Int = DEFAULT_LIMIT,
    inferNullability: Boolean = true,
): AnyFrame {
    // The query check goes first, so anything containing the SELECT keyword is never treated as a table name.
    if (isSqlQuery(sqlQueryOrTableName)) {
        return DataFrame.readSqlQuery(this, sqlQueryOrTableName, limit, inferNullability)
    }
    if (isSqlTableName(sqlQueryOrTableName)) {
        return DataFrame.readSqlTable(this, sqlQueryOrTableName, limit, inferNullability)
    }
    throw IllegalArgumentException(
        "$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!",
    )
}

/** An SQL query is accepted only if it starts with SELECT. */
private fun isValid(sqlQuery: String): Boolean {
val normalizedSqlQuery = sqlQuery.trim().uppercase()
Expand Down Expand Up @@ -256,6 +339,30 @@ public fun DataFrame.Companion.readResultSet(
return fetchAndConvertDataFromResultSet(tableColumns, resultSet, dbType, limit, inferNullability)
}

/**
 * Reads the data from a [ResultSet][java.sql.ResultSet] and converts it into a DataFrame.
 *
 * A [ResultSet][java.sql.ResultSet] object maintains a cursor pointing to its current row of data.
 * By default, a ResultSet object is not updatable and has a cursor that can only move forward.
 * Therefore, you can iterate through it only once, from the first row to the last row.
 *
 * For more details, refer to the official Java documentation on [ResultSet][java.sql.ResultSet].
 *
 * NOTE: Reading from the [ResultSet][java.sql.ResultSet] could potentially change its state.
 *
 * This is the extension-function form of [DataFrame.readResultSet]; all arguments are forwarded to it unchanged.
 *
 * @param [dbType] the type of database that the [ResultSet] belongs to.
 * @param [limit] the maximum number of rows to read from the [ResultSet][java.sql.ResultSet].
 * @param [inferNullability] indicates how the column nullability should be inferred.
 * @return the DataFrame generated from the [ResultSet][java.sql.ResultSet] data.
 *
 * [java.sql.ResultSet]: https://docs.oracle.com/javase/8/docs/api/java/sql/ResultSet.html
 */
public fun ResultSet.readDataFrame(
    dbType: DbType,
    limit: Int = DEFAULT_LIMIT,
    inferNullability: Boolean = true,
): AnyFrame = DataFrame.Companion.readResultSet(this, dbType, limit, inferNullability)

/**
* Reads the data from a [ResultSet][java.sql.ResultSet] and converts it into a DataFrame.
*
Expand Down Expand Up @@ -288,6 +395,31 @@ public fun DataFrame.Companion.readResultSet(
return readResultSet(resultSet, dbType, limit, inferNullability)
}

/**
 * Converts the rows of this [ResultSet][java.sql.ResultSet] into a DataFrame.
 *
 * A [ResultSet][java.sql.ResultSet] keeps a cursor on its current row.
 * By default, it is not updatable and its cursor only moves forward,
 * so it can be traversed just once, from the first row to the last one.
 *
 * See the official Java documentation on [ResultSet][java.sql.ResultSet] for more details.
 *
 * NOTE: Reading from the [ResultSet][java.sql.ResultSet] could potentially change its state.
 *
 * @param [connection] the connection to the database that the [ResultSet] belongs to
 * (it's required to extract the database type).
 * @param [limit] the maximum number of rows to read from the [ResultSet][java.sql.ResultSet].
 * @param [inferNullability] indicates how the column nullability should be inferred.
 * @return the DataFrame generated from the [ResultSet][java.sql.ResultSet] data.
 *
 * [java.sql.ResultSet]: https://docs.oracle.com/javase/8/docs/api/java/sql/ResultSet.html
 */
public fun ResultSet.readDataFrame(
    connection: Connection,
    limit: Int = DEFAULT_LIMIT,
    inferNullability: Boolean = true,
): AnyFrame {
    // Plain delegation: the receiver becomes the first argument of the companion-based reader.
    return DataFrame.readResultSet(this, connection, limit, inferNullability)
}

/**
* Reads all non-system tables from a database and returns them
* as a map of SQL tables and corresponding dataframes using the provided database configuration and limit.
Expand All @@ -299,7 +431,7 @@ public fun DataFrame.Companion.readResultSet(
* @return a map of [String] to [AnyFrame] objects representing the non-system tables from the database.
*/
public fun DataFrame.Companion.readAllSqlTables(
dbConfig: DatabaseConfiguration,
dbConfig: DbConnectionConfig,
catalogue: String? = null,
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
Expand Down Expand Up @@ -366,10 +498,7 @@ public fun DataFrame.Companion.readAllSqlTables(
* @param [tableName] the name of the SQL table for which to retrieve the schema.
* @return the [DataFrameSchema] object representing the schema of the SQL table
*/
public fun DataFrame.Companion.getSchemaForSqlTable(
dbConfig: DatabaseConfiguration,
tableName: String,
): DataFrameSchema {
public fun DataFrame.Companion.getSchemaForSqlTable(dbConfig: DbConnectionConfig, tableName: String): DataFrameSchema {
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
return getSchemaForSqlTable(connection, tableName)
}
Expand Down Expand Up @@ -405,10 +534,7 @@ public fun DataFrame.Companion.getSchemaForSqlTable(connection: Connection, tabl
* @param [sqlQuery] the SQL query to execute and retrieve the schema from.
* @return the schema of the SQL query as a [DataFrameSchema] object.
*/
public fun DataFrame.Companion.getSchemaForSqlQuery(
dbConfig: DatabaseConfiguration,
sqlQuery: String,
): DataFrameSchema {
public fun DataFrame.Companion.getSchemaForSqlQuery(dbConfig: DbConnectionConfig, sqlQuery: String): DataFrameSchema {
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
return getSchemaForSqlQuery(connection, sqlQuery)
}
Expand All @@ -434,6 +560,40 @@ public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQ
}
}

/**
 * Retrieves the schema of an SQL query result or of an SQL table, using this database configuration.
 *
 * The argument is first tested as a query and, only if that fails, as a table name;
 * the work is then delegated to [DataFrame.getSchemaForSqlQuery] or [DataFrame.getSchemaForSqlTable] respectively.
 *
 * @param [sqlQueryOrTableName] the SQL query or the name of the SQL table to retrieve the schema from.
 * @return the schema of the SQL query result or of the SQL table as a [DataFrameSchema] object.
 * @throws IllegalArgumentException if [sqlQueryOrTableName] is recognized neither as a query nor as a table name.
 */
public fun DbConnectionConfig.getDataFrameSchema(sqlQueryOrTableName: String): DataFrameSchema {
    if (isSqlQuery(sqlQueryOrTableName)) {
        return DataFrame.getSchemaForSqlQuery(this, sqlQueryOrTableName)
    }
    if (isSqlTableName(sqlQueryOrTableName)) {
        return DataFrame.getSchemaForSqlTable(this, sqlQueryOrTableName)
    }
    throw IllegalArgumentException(
        "$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!",
    )
}

/**
 * Retrieves the schema of an SQL query result or of an SQL table through this database connection.
 *
 * The argument is first tested as a query and, only if that fails, as a table name;
 * the work is then delegated to [DataFrame.getSchemaForSqlQuery] or [DataFrame.getSchemaForSqlTable] respectively.
 *
 * @param [sqlQueryOrTableName] the SQL query or the name of the SQL table to retrieve the schema from.
 * @return the schema of the SQL query result or of the SQL table as a [DataFrameSchema] object.
 * @throws IllegalArgumentException if [sqlQueryOrTableName] is recognized neither as a query nor as a table name.
 */
public fun Connection.getDataFrameSchema(sqlQueryOrTableName: String): DataFrameSchema {
    if (isSqlQuery(sqlQueryOrTableName)) {
        return DataFrame.getSchemaForSqlQuery(this, sqlQueryOrTableName)
    }
    if (isSqlTableName(sqlQueryOrTableName)) {
        return DataFrame.getSchemaForSqlTable(this, sqlQueryOrTableName)
    }
    throw IllegalArgumentException(
        "$sqlQueryOrTableName should be SQL query or name of one of the existing SQL tables!",
    )
}

/**
* Retrieves the schema from [ResultSet].
*
Expand All @@ -448,6 +608,16 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbTyp
return buildSchemaByTableColumns(tableColumns, dbType)
}

/**
 * Retrieves the schema from this [ResultSet].
 *
 * NOTE: This function neither closes the connection or the result set nor retrieves data from the result set.
 *
 * @param [dbType] the type of database that the [ResultSet] belongs to.
 * @return the schema of the [ResultSet] as a [DataFrameSchema] object.
 */
public fun ResultSet.getDataFrameSchema(dbType: DbType): DataFrameSchema {
    // Plain delegation: the receiver becomes the first argument of the companion-based function.
    return DataFrame.Companion.getSchemaForResultSet(this, dbType)
}

/**
* Retrieves the schema from [ResultSet].
*
Expand All @@ -465,13 +635,24 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, conne
return buildSchemaByTableColumns(tableColumns, dbType)
}

/**
 * Retrieves the schema from this [ResultSet].
 *
 * NOTE: This function neither closes the connection or the result set nor retrieves data from the result set.
 *
 * @param [connection] the connection to the database (it's required to extract the database type).
 * @return the schema of the [ResultSet] as a [DataFrameSchema] object.
 */
public fun ResultSet.getDataFrameSchema(connection: Connection): DataFrameSchema {
    // Plain delegation: the receiver becomes the first argument of the companion-based function.
    return DataFrame.Companion.getSchemaForResultSet(this, connection)
}

/**
* Retrieves the schemas of all non-system tables in the database using the provided database configuration.
*
* @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
* @return a map from the name of each non-system table ([String]) to its schema ([DataFrameSchema]).
*/
public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): Map<String, DataFrameSchema> {
public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DbConnectionConfig): Map<String, DataFrameSchema> {
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
return getSchemaForAllSqlTables(connection)
}
Expand Down
Loading
Loading