diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md
index 3b83bf5bc1473..dec7bc36116ba 100644
--- a/docs/sql-data-sources-jdbc.md
+++ b/docs/sql-data-sources-jdbc.md
@@ -103,7 +103,7 @@ logging into the data sources.
     (none)
     A prefix that will form the final query together with query.
-    As the specified query will be parenthesized as a subquery in the FROM clause and some databases do not
+    As the specified query will be parenthesized as a subquery in the FROM clause and some databases do not
     support all clauses in subqueries, the prepareQuery property offers a way to run such complex queries.
    As an example, spark will issue a query of the following form to the JDBC Source.

<prepareQuery> SELECT <columns> FROM (<user_specified_query>) spark_gen_alias

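For illustration only, a minimal sketch of how prepareQuery and query combine on the read path (the jdbcUrl value, the table tbl, and the columns below are hypothetical placeholders, not part of this patch):

```scala
// Hypothetical example: a WITH clause cannot appear inside the generated subquery on some
// databases, so it is supplied via prepareQuery and prepended to the final statement.
val df = spark.read
  .format("jdbc")
  .option("url", jdbcUrl) // placeholder JDBC URL
  .option("prepareQuery", "WITH t AS (SELECT x, y FROM tbl)")
  .option("query", "SELECT * FROM t WHERE x > 10")
  .load()
// Spark then issues roughly:
//   WITH t AS (SELECT x, y FROM tbl) SELECT <columns> FROM (SELECT * FROM t WHERE x > 10) spark_gen_alias
```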
@@ -340,10 +340,19 @@ logging into the data sources.
     The name of the JDBC connection provider to use to connect to this URL, e.g. db2, mssql.
     Must be one of the providers loaded with the JDBC data source. Used to disambiguate when more than one provider can handle
-    the specified driver and options. The selected provider must not be disabled by spark.sql.sources.disabledJdbcConnProviderList.
+    the specified driver and options. The selected provider must not be disabled by spark.sql.sources.disabledJdbcConnProviderList.
     read/write
-
+
+
+  inferTimestampNTZType
+  false
+
+    When the option is set to true, all timestamps are inferred as TIMESTAMP WITHOUT TIME ZONE.
+    Otherwise, timestamps are read as TIMESTAMP with local time zone.
+
+  read
+
 Note that kerberos authentication with keytab is not always supported by the JDBC driver.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index df21a9820f9bf..80675c7dc4711 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -226,6 +226,9 @@ class JDBCOptions(
   // The prefix that is added to the query sent to the JDBC database.
   // This is required to support some complex queries with some JDBC databases.
   val prepareQuery = parameters.get(JDBC_PREPARE_QUERY).map(_ + " ").getOrElse("")
+
+  // Infers timestamp values as TimestampNTZ type when reading data.
+  val inferTimestampNTZType = parameters.getOrElse(JDBC_INFER_TIMESTAMP_NTZ, "false").toBoolean
 }
 
 class JdbcOptionsInWrite(
@@ -287,4 +290,5 @@ object JDBCOptions {
   val JDBC_REFRESH_KRB5_CONFIG = newOption("refreshKrb5Config")
   val JDBC_CONNECTION_PROVIDER = newOption("connectionProvider")
   val JDBC_PREPARE_QUERY = newOption("prepareQuery")
+  val JDBC_INFER_TIMESTAMP_NTZ = newOption("inferTimestampNTZType")
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 27d2d9c84c344..e95fe280c760e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -67,7 +67,8 @@ object JDBCRDD extends Logging {
       statement.setQueryTimeout(options.queryTimeout)
       val rs = statement.executeQuery()
       try {
-        JdbcUtils.getSchema(rs, dialect, alwaysNullable = true)
+        JdbcUtils.getSchema(rs, dialect, alwaysNullable = true,
+          isTimestampNTZ = options.inferTimestampNTZType)
       } finally {
         rs.close()
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 1f17d4f0b14cf..cc8746ea5c407 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
 import org.apache.spark.sql.connector.catalog.TableChange
 import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex}
 import org.apache.spark.sql.connector.expressions.NamedReference
@@ -150,6 +150,10 @@ object JdbcUtils extends Logging with SQLConfHelper {
     case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
     case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
     case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
+    // This is a common case of timestamp without time zone. Most of the databases either only
+    // support TIMESTAMP type or use TIMESTAMP as an alias for TIMESTAMP WITHOUT TIME ZONE.
+    // Note that some dialects override this setting, e.g. SQL Server.
+    case TimestampNTZType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
     case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
     case t: DecimalType => Option(
       JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
@@ -173,7 +177,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
       sqlType: Int,
       precision: Int,
       scale: Int,
-      signed: Boolean): DataType = {
+      signed: Boolean,
+      isTimestampNTZ: Boolean): DataType = {
     val answer = sqlType match {
       // scalastyle:off
       case java.sql.Types.ARRAY => null
@@ -215,6 +220,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
       case java.sql.Types.TIME => TimestampType
       case java.sql.Types.TIME_WITH_TIMEZONE
        => null
+      case java.sql.Types.TIMESTAMP
+        if isTimestampNTZ => TimestampNTZType
       case java.sql.Types.TIMESTAMP => TimestampType
       case java.sql.Types.TIMESTAMP_WITH_TIMEZONE
        => null
@@ -243,7 +250,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
       conn.prepareStatement(options.prepareQuery + dialect.getSchemaQuery(options.tableOrQuery))
     try {
       statement.setQueryTimeout(options.queryTimeout)
-      Some(getSchema(statement.executeQuery(), dialect))
+      Some(getSchema(statement.executeQuery(), dialect,
+        isTimestampNTZ = options.inferTimestampNTZType))
     } catch {
       case _: SQLException => None
     } finally {
@@ -258,13 +266,15 @@ object JdbcUtils extends Logging with SQLConfHelper {
    * Takes a [[ResultSet]] and returns its Catalyst schema.
    *
    * @param alwaysNullable If true, all the columns are nullable.
+   * @param isTimestampNTZ If true, all timestamp columns are interpreted as TIMESTAMP_NTZ.
    * @return A [[StructType]] giving the Catalyst schema.
    * @throws SQLException if the schema contains an unsupported type.
   */
  def getSchema(
      resultSet: ResultSet,
      dialect: JdbcDialect,
-     alwaysNullable: Boolean = false): StructType = {
+     alwaysNullable: Boolean = false,
+     isTimestampNTZ: Boolean = false): StructType = {
    val rsmd = resultSet.getMetaData
    val ncols = rsmd.getColumnCount
    val fields = new Array[StructField](ncols)
@@ -306,7 +316,7 @@
       val columnType =
         dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
-          getCatalystType(dataType, fieldSize, fieldScale, isSigned))
+          getCatalystType(dataType, fieldSize, fieldScale, isSigned, isTimestampNTZ))
       fields(i) = StructField(columnName, columnType, nullable, metadata.build())
       i = i + 1
     }
@@ -463,7 +473,7 @@
         }
       }
 
-    case TimestampType =>
+    case TimestampType | TimestampNTZType =>
       (rs: ResultSet, row: InternalRow, pos: Int) =>
         val t = rs.getTimestamp(pos + 1)
         if (t != null) {
@@ -583,6 +593,18 @@
           stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
       }
 
+    case TimestampNTZType =>
+      if (conf.datetimeJava8ApiEnabled) {
+        (stmt: PreparedStatement, row: Row, pos: Int) =>
+          stmt.setTimestamp(pos + 1, toJavaTimestamp(instantToMicros(row.getAs[Instant](pos))))
+      } else {
+        (stmt: PreparedStatement, row: Row, pos: Int) =>
+          stmt.setTimestamp(
+            pos + 1,
+            toJavaTimestamp(localDateTimeToMicros(row.getAs[java.time.LocalDateTime](pos)))
+          )
+      }
+
    case DateType =>
      if (conf.datetimeJava8ApiEnabled) {
        (stmt: PreparedStatement, row: Row, pos: Int) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index a42129dbe8da8..c95489a28761b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -98,6 +98,7 @@ private object MsSqlServerDialect extends JdbcDialect {
   override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
     case TimestampType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP))
+    case TimestampNTZType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP))
     case StringType => Some(JdbcType("NVARCHAR(MAX)", java.sql.Types.NVARCHAR))
     case BooleanType => Some(JdbcType("BIT", java.sql.Types.BIT))
     case BinaryType => Some(JdbcType("VARBINARY(MAX)", java.sql.Types.VARBINARY))
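The comment added in JdbcUtils above notes that dialects may override the default TIMESTAMP mapping, as MsSqlServerDialect does here. As a hedged sketch (the dialect name, URL prefix, and column type below are invented for illustration), a third-party dialect could do the same through the public JdbcDialect API:

```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types.{DataType, TimestampNTZType}

// Hypothetical dialect that maps TimestampNTZType to a vendor-specific column definition
// instead of the generic TIMESTAMP fallback in JdbcUtils.
object MyNtzDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case TimestampNTZType => Some(JdbcType("TIMESTAMP(6)", java.sql.Types.TIMESTAMP))
    case _ => None // fall back to the defaults for every other type
  }
}

// Registering the dialect makes it apply to any URL it canHandle.
JdbcDialects.registerDialect(MyNtzDialect)
```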
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index c96b27ee7f3fa..a07ef5ecd3009 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc
 import java.math.BigDecimal
 import java.sql.{Date, DriverManager, SQLException, Timestamp}
-import java.time.{Instant, LocalDate}
+import java.time.{Instant, LocalDate, LocalDateTime}
 import java.util.{Calendar, GregorianCalendar, Properties, TimeZone}
 
 import scala.collection.JavaConverters._
@@ -1230,6 +1230,7 @@ class JDBCSuite extends QueryTest
     assert(getJdbcType(oracleDialect, BinaryType) == "BLOB")
     assert(getJdbcType(oracleDialect, DateType) == "DATE")
     assert(getJdbcType(oracleDialect, TimestampType) == "TIMESTAMP")
+    assert(getJdbcType(oracleDialect, TimestampNTZType) == "TIMESTAMP")
   }
 
   private def assertEmptyQuery(sqlString: String): Unit = {
@@ -1879,5 +1880,53 @@ class JDBCSuite extends QueryTest
     val fields = schema.fields
     assert(fields.length === 1)
     assert(fields(0).dataType === StringType)
-  }
+  }
+
+  test("SPARK-39339: Handle TimestampNTZType null values") {
+    val tableName = "timestamp_ntz_null_table"
+
+    val df = Seq(null.asInstanceOf[LocalDateTime]).toDF("col1")
+
+    df.write.format("jdbc")
+      .option("url", urlWithUserAndPass)
+      .option("dbtable", tableName).save()
+
+    val res = spark.read.format("jdbc")
+      .option("inferTimestampNTZType", "true")
+      .option("url", urlWithUserAndPass)
+      .option("dbtable", tableName)
+      .load()
+
+    checkAnswer(res, Seq(Row(null)))
+  }
+
+  test("SPARK-39339: TimestampNTZType with different local time zones") {
+    val tableName = "timestamp_ntz_diff_tz_support_table"
+
+    DateTimeTestUtils.outstandingZoneIds.foreach { zoneId =>
+      DateTimeTestUtils.withDefaultTimeZone(zoneId) {
+        Seq(
+          "1972-07-04 03:30:00",
+          "2019-01-20 12:00:00.502",
+          "2019-01-20T00:00:00.123456",
+          "1500-01-20T00:00:00.123456"
+        ).foreach { datetime =>
+          val df = spark.sql(s"select timestamp_ntz '$datetime'")
+          df.write.format("jdbc")
+            .mode("overwrite")
+            .option("url", urlWithUserAndPass)
+            .option("dbtable", tableName)
+            .save()
+
+          val res = spark.read.format("jdbc")
+            .option("inferTimestampNTZType", "true")
+            .option("url", urlWithUserAndPass)
+            .option("dbtable", tableName)
+            .load()
+
+          checkAnswer(res, df)
+        }
+      }
+    }
+  }
 }
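To tie the pieces together, a hedged round-trip sketch in the spirit of the tests above (the H2 URL, table name, and the existing SparkSession named spark are placeholders/assumptions, not part of the patch):

```scala
import java.time.LocalDateTime
import org.apache.spark.sql.types.{TimestampNTZType, TimestampType}
import spark.implicits._ // assumes a SparkSession named `spark` is already in scope

val url = "jdbc:h2:mem:testdb;user=testUser;password=testPass" // placeholder URL
val table = "ntz_roundtrip"                                    // placeholder table

// Write a TIMESTAMP_NTZ value through JDBC; the JdbcUtils change above maps it to a TIMESTAMP column.
Seq(LocalDateTime.parse("2019-01-20T12:00:00")).toDF("ts")
  .write.format("jdbc").option("url", url).option("dbtable", table).save()

// Read it back with and without the new option and compare the inferred column types.
val withNtz = spark.read.format("jdbc")
  .option("url", url).option("dbtable", table)
  .option("inferTimestampNTZType", "true").load()
val withoutNtz = spark.read.format("jdbc")
  .option("url", url).option("dbtable", table).load()

assert(withNtz.schema.head.dataType == TimestampNTZType)  // no time zone attached
assert(withoutNtz.schema.head.dataType == TimestampType)  // session-local time zone semantics
```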