diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index ff5127ce350f5..f840876fc5d00 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc
 import java.math.{BigDecimal => JBigDecimal}
 import java.sql.{Connection, Date, Timestamp}
 import java.text.SimpleDateFormat
+import java.time.LocalDateTime
 import java.util.Properties
 
 import org.apache.spark.sql.Column
@@ -140,6 +141,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
       "c0 money)").executeUpdate()
     conn.prepareStatement("INSERT INTO money_types VALUES " +
       "('$1,000.00')").executeUpdate()
+
+    conn.prepareStatement(s"CREATE TABLE timestamp_ntz(v timestamp)").executeUpdate()
+    conn.prepareStatement(s"""INSERT INTO timestamp_ntz VALUES
+        |('2013-04-05 12:01:02'),
+        |('2013-04-05 18:01:02.123'),
+        |('2013-04-05 18:01:02.123456')""".stripMargin).executeUpdate()
   }
 
   test("Type mapping for various types") {
@@ -381,4 +388,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(row(0).length === 1)
     assert(row(0).getString(0) === "$1,000.00")
   }
+
+  test("SPARK-43040: timestamp_ntz read test") {
+    val prop = new Properties
+    prop.setProperty("preferTimestampNTZ", "true")
+    val df = sqlContext.read.jdbc(jdbcUrl, "timestamp_ntz", prop)
+    val row = df.collect()
+    assert(row.length === 3)
+    assert(row(0).length === 1)
+    assert(row(0) === Row(LocalDateTime.of(2013, 4, 5, 12, 1, 2)))
+    assert(row(1) === Row(LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123000000)))
+    assert(row(2) === Row(LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123456000)))
+  }
+
+  test("SPARK-43040: timestamp_ntz roundtrip test") {
+    val prop = new Properties
+    prop.setProperty("preferTimestampNTZ", "true")
+
+    val sparkQuery = """
+      |select
+      |  timestamp_ntz'2020-12-10 11:22:33' as col0
+      """.stripMargin
+
+    val df_expected = sqlContext.sql(sparkQuery)
+    df_expected.write.jdbc(jdbcUrl, "timestamp_ntz_roundtrip", prop)
+
+    val df_actual = sqlContext.read.jdbc(jdbcUrl, "timestamp_ntz_roundtrip", prop)
+    assert(df_actual.collect()(0) == df_expected.collect()(0))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 70e29f5d7195c..e241951abe392 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -273,7 +273,8 @@ private[jdbc] class JDBCRDD(
     stmt.setFetchSize(options.fetchSize)
     stmt.setQueryTimeout(options.queryTimeout)
     rs = stmt.executeQuery()
-    val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
+    val rowsIterator =
+      JdbcUtils.resultSetToSparkInternalRows(rs, dialect, schema, inputMetrics)
 
     CompletionIterator[InternalRow, Iterator[InternalRow]](
       new InterruptibleIterator(context, rowsIterator), close())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index fe53ba91d9592..d907ce6b100cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -38,12 +38,12 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, GenericArrayData}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros, localDateToDays, toJavaDate, toJavaTimestamp, toJavaTimestampNoRebase}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
 import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
 import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex}
 import org.apache.spark.sql.connector.expressions.NamedReference
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, NoopDialect}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
 import org.apache.spark.unsafe.types.UTF8String
@@ -316,21 +316,31 @@ object JdbcUtils extends Logging with SQLConfHelper {
   /**
    * Convert a [[ResultSet]] into an iterator of Catalyst Rows.
    */
-  def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = {
+  def resultSetToRows(
+      resultSet: ResultSet,
+      schema: StructType): Iterator[Row] = {
+    resultSetToRows(resultSet, schema, NoopDialect)
+  }
+
+  def resultSetToRows(
+      resultSet: ResultSet,
+      schema: StructType,
+      dialect: JdbcDialect): Iterator[Row] = {
     val inputMetrics =
       Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics)
     val fromRow = RowEncoder(schema).resolveAndBind().createDeserializer()
-    val internalRows = resultSetToSparkInternalRows(resultSet, schema, inputMetrics)
+    val internalRows = resultSetToSparkInternalRows(resultSet, dialect, schema, inputMetrics)
     internalRows.map(fromRow)
   }
 
   private[spark] def resultSetToSparkInternalRows(
       resultSet: ResultSet,
+      dialect: JdbcDialect,
       schema: StructType,
       inputMetrics: InputMetrics): Iterator[InternalRow] = {
     new NextIterator[InternalRow] {
       private[this] val rs = resultSet
-      private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema)
+      private[this] val getters: Array[JDBCValueGetter] = makeGetters(dialect, schema)
       private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))
 
       override protected def close(): Unit = {
@@ -368,12 +378,17 @@ object JdbcUtils extends Logging with SQLConfHelper {
    * Creates `JDBCValueGetter`s according to [[StructType]], which can set
    * each value from `ResultSet` to each field of [[InternalRow]] correctly.
    */
-  private def makeGetters(schema: StructType): Array[JDBCValueGetter] = {
+  private def makeGetters(
+      dialect: JdbcDialect,
+      schema: StructType): Array[JDBCValueGetter] = {
     val replaced = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema)
-    replaced.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
+    replaced.fields.map(sf => makeGetter(sf.dataType, dialect, sf.metadata))
   }
 
-  private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
+  private def makeGetter(
+      dt: DataType,
+      dialect: JdbcDialect,
+      metadata: Metadata): JDBCValueGetter = dt match {
     case BooleanType =>
       (rs: ResultSet, row: InternalRow, pos: Int) =>
         row.setBoolean(pos, rs.getBoolean(pos + 1))
@@ -478,7 +493,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
       (rs: ResultSet, row: InternalRow, pos: Int) =>
         val t = rs.getTimestamp(pos + 1)
         if (t != null) {
-          row.setLong(pos, DateTimeUtils.fromJavaTimestampNoRebase(t))
+          row.setLong(pos,
+            DateTimeUtils.localDateTimeToMicros(dialect.convertJavaTimestampToTimestampNTZ(t)))
         } else {
           row.update(pos, null)
         }
@@ -596,8 +612,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
 
     case TimestampNTZType =>
       (stmt: PreparedStatement, row: Row, pos: Int) =>
-        val micros = localDateTimeToMicros(row.getAs[java.time.LocalDateTime](pos))
-        stmt.setTimestamp(pos + 1, toJavaTimestampNoRebase(micros))
+        stmt.setTimestamp(pos + 1,
+          dialect.convertTimestampNTZToJavaTimestamp(row.getAs[java.time.LocalDateTime](pos)))
 
     case DateType =>
       if (conf.datetimeJava8ApiEnabled) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index e7a74ee3aa9c6..93a311be2f867 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.jdbc
 
 import java.sql.{Connection, Date, Driver, Statement, Timestamp}
-import java.time.{Instant, LocalDate}
+import java.time.{Instant, LocalDate, LocalDateTime}
 import java.util
 
 import scala.collection.mutable.ArrayBuilder
@@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateTimeToMicros, toJavaTimestampNoRebase}
 import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
 import org.apache.spark.sql.connector.catalog.TableChange._
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
@@ -104,6 +105,31 @@ abstract class JdbcDialect extends Serializable with Logging {
    */
   def getJDBCType(dt: DataType): Option[JdbcType] = None
 
+  /**
+   * Converts a java.sql.Timestamp to a LocalDateTime representing the same wall-clock time as
+   * the value stored in a remote database.
+   * JDBC dialects should override this function to provide implementations that suit their
+   * JDBC drivers.
+   * @param t Timestamp returned from JDBC driver getTimestamp method.
+   * @return A LocalDateTime representing the same wall-clock time as the timestamp in the database.
+   */
+  @Since("3.5.0")
+  def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime = {
+    DateTimeUtils.microsToLocalDateTime(DateTimeUtils.fromJavaTimestampNoRebase(t))
+  }
+
+  /**
+   * Converts a LocalDateTime representing a TimestampNTZ type to an
+   * instance of `java.sql.Timestamp`.
+   * @param ldt A LocalDateTime representing a TimestampNTZType value.
+   * @return A Java Timestamp representing this LocalDateTime.
+   */
+  @Since("3.5.0")
+  def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp = {
+    val micros = localDateTimeToMicros(ldt)
+    toJavaTimestampNoRebase(micros)
+  }
+
   /**
    * Returns a factory for creating connections to the given JDBC URL.
    * In general, creating a connection has nothing to do with JDBC partition id.
@@ -682,6 +708,6 @@ object JdbcDialects {
 /**
  * NOOP dialect object, always returning the neutral element.
  */
-private object NoopDialect extends JdbcDialect {
+object NoopDialect extends JdbcDialect {
   override def canHandle(url : String): Boolean = true
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index b53a0e66ba752..b42d575ae2d47 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -17,7 +17,8 @@ package org.apache.spark.sql.jdbc
 
-import java.sql.{Connection, SQLException, Types}
+import java.sql.{Connection, SQLException, Timestamp, Types}
+import java.time.LocalDateTime
 import java.util
 import java.util.Locale
@@ -102,6 +103,14 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper {
     case _ => None
   }
 
+  override def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime = {
+    t.toLocalDateTime
+  }
+
+  override def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp = {
+    Timestamp.valueOf(ldt)
+  }
+
   override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
     case StringType => Some(JdbcType("TEXT", Types.VARCHAR))
     case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
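
For illustration only (a sketch, not part of the patch): the two JdbcDialect hooks added above can also be overridden by a custom dialect registered from user code. The object name and the "jdbc:example" URL prefix below are hypothetical; JdbcDialect, JdbcDialects.registerDialect, java.sql.Timestamp, java.time.LocalDateTime, and the preferTimestampNTZ read option are the real APIs involved.

import java.sql.Timestamp
import java.time.LocalDateTime

import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// Hypothetical dialect that, like PostgresDialect in the patch, treats the driver's
// java.sql.Timestamp as a plain wall-clock value with no time-zone rebase.
object ExampleNTZDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:example")

  // TIMESTAMP read path: keep the driver's wall-clock value as-is.
  override def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime =
    t.toLocalDateTime

  // TIMESTAMP write path: the inverse conversion for TimestampNTZType columns.
  override def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp =
    Timestamp.valueOf(ldt)
}

object ExampleNTZDialectApp {
  def main(args: Array[String]): Unit = {
    // Once registered, JDBC reads with preferTimestampNTZ=true route TIMESTAMP
    // values for matching URLs through the overrides above.
    JdbcDialects.registerDialect(ExampleNTZDialect)
  }
}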