[SPARK-43040][SQL] Improve TimestampNTZ type support in JDBC data source
### What changes were proposed in this pull request?

#36726 added support for the TimestampNTZ type in the JDBC data source, and #37013 applied a fix to pass more test cases with H2.

The problem is that `java.sql.Timestamp` is a poorly defined class, and different JDBC drivers implement `getTimestamp` and `setTimestamp` with different expected behaviors in mind. The generic conversion implementation works with some JDBC dialects and their drivers but not with others. This issue was discovered while testing against a PostgreSQL database.
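
To illustrate the ambiguity with a hypothetical snippet (not part of this patch): a single `java.sql.Timestamp` exposes both an epoch-based instant, which depends on the JVM default time zone, and a wall-clock reading, which does not, and drivers differ on which of the two `getTimestamp`/`setTimestamp` should round-trip.

```scala
import java.sql.Timestamp

// Hypothetical illustration only: the same Timestamp value exposes two readings.
object TimestampAmbiguity {
  def main(args: Array[String]): Unit = {
    // Parsed as wall-clock time in the JVM default time zone.
    val ts = Timestamp.valueOf("2013-04-05 12:01:02")
    // Wall-clock view: always 2013-04-05T12:01:02, regardless of time zone.
    println(ts.toLocalDateTime)
    // Instant view: epoch millis change if the JVM default time zone changes.
    println(ts.getTime)
  }
}
```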

This PR adds a `dialect` parameter to `makeGetter` so that dialect-specific conversions can be applied when reading a Java Timestamp into TimestampNTZType. `makeSetter` already takes a `dialect` parameter, which is now used for converting back to a Java Timestamp.
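
For reference, a minimal sketch of the kind of dialect-specific override this enables (the dialect name and URL prefix below are made up; the two `convert*` methods are the new `JdbcDialect` hooks added in this PR, shown in the diff further down):

```scala
import java.sql.Timestamp
import java.time.LocalDateTime

import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// Sketch of a dialect whose driver treats getTimestamp/setTimestamp as plain
// wall-clock values, mirroring the PostgresDialect override in this PR.
case object MyWallClockDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb:")

  // Read path: keep the driver-provided wall-clock time as-is.
  override def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime =
    t.toLocalDateTime

  // Write path: hand the driver a Timestamp with the same wall-clock time.
  override def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp =
    Timestamp.valueOf(ldt)
}

// Usage: register the dialect so it is picked up for matching JDBC URLs, e.g.
// JdbcDialects.registerDialect(MyWallClockDialect)
```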

### Why are the changes needed?

Fixes TimestampNTZ support for PostgreSQL and allows other JDBC dialects to provide dialect-specific implementations for converting between Java Timestamp and Spark's TimestampNTZType.
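
As a hedged usage sketch (connection URL, credentials, and table name are placeholders), this is how a user would read such a column with the `preferTimestampNTZ` JDBC option, mirroring the new integration test:

```scala
import org.apache.spark.sql.SparkSession

object ReadTimestampNtz {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("ntz-read").getOrCreate()
    // Placeholder PostgreSQL URL and table; with this fix the column is read
    // through the dialect-specific TimestampNTZ conversion path.
    val df = spark.read
      .format("jdbc")
      .option("url", "jdbc:postgresql://localhost:5432/testdb")
      .option("user", "postgres")
      .option("password", "postgres")
      .option("dbtable", "timestamp_ntz")
      .option("preferTimestampNTZ", "true")
      .load()
    df.printSchema() // expect the column to be typed as timestamp_ntz
    spark.stop()
  }
}
```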

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing unit tests, plus new test cases in `PostgresIntegrationSuite` covering TimestampNTZ reads and writes.

Closes #40678 from tianhanhu/SPARK-43040_jdbc_timestamp_ntz.

Authored-by: tianhanhu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
tianhanhu authored and cloud-fan committed May 5, 2023
1 parent e613563 commit 0c4ac71
Showing 5 changed files with 102 additions and 15 deletions.
@@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc
import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Connection, Date, Timestamp}
import java.text.SimpleDateFormat
import java.time.LocalDateTime
import java.util.Properties

import org.apache.spark.sql.Column
@@ -140,6 +141,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
"c0 money)").executeUpdate()
conn.prepareStatement("INSERT INTO money_types VALUES " +
"('$1,000.00')").executeUpdate()

conn.prepareStatement(s"CREATE TABLE timestamp_ntz(v timestamp)").executeUpdate()
conn.prepareStatement(s"""INSERT INTO timestamp_ntz VALUES
|('2013-04-05 12:01:02'),
|('2013-04-05 18:01:02.123'),
|('2013-04-05 18:01:02.123456')""".stripMargin).executeUpdate()
}

test("Type mapping for various types") {
@@ -381,4 +388,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(row(0).length === 1)
assert(row(0).getString(0) === "$1,000.00")
}

test("SPARK-43040: timestamp_ntz read test") {
val prop = new Properties
prop.setProperty("preferTimestampNTZ", "true")
val df = sqlContext.read.jdbc(jdbcUrl, "timestamp_ntz", prop)
val row = df.collect()
assert(row.length === 3)
assert(row(0).length === 1)
assert(row(0) === Row(LocalDateTime.of(2013, 4, 5, 12, 1, 2)))
assert(row(1) === Row(LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123000000)))
assert(row(2) === Row(LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123456000)))
}

test("SPARK-43040: timestamp_ntz roundtrip test") {
val prop = new Properties
prop.setProperty("preferTimestampNTZ", "true")

val sparkQuery = """
|select
| timestamp_ntz'2020-12-10 11:22:33' as col0
""".stripMargin

val df_expected = sqlContext.sql(sparkQuery)
df_expected.write.jdbc(jdbcUrl, "timestamp_ntz_roundtrip", prop)

val df_actual = sqlContext.read.jdbc(jdbcUrl, "timestamp_ntz_roundtrip", prop)
assert(df_actual.collect()(0) == df_expected.collect()(0))
}
}
@@ -273,7 +273,8 @@ private[jdbc] class JDBCRDD(
stmt.setFetchSize(options.fetchSize)
stmt.setQueryTimeout(options.queryTimeout)
rs = stmt.executeQuery()
val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
val rowsIterator =
JdbcUtils.resultSetToSparkInternalRows(rs, dialect, schema, inputMetrics)

CompletionIterator[InternalRow, Iterator[InternalRow]](
new InterruptibleIterator(context, rowsIterator), close())
@@ -38,12 +38,12 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros, localDateToDays, toJavaDate, toJavaTimestamp, toJavaTimestampNoRebase}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex}
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, NoopDialect}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.unsafe.types.UTF8String
@@ -316,21 +316,31 @@ object JdbcUtils extends Logging with SQLConfHelper {
/**
* Convert a [[ResultSet]] into an iterator of Catalyst Rows.
*/
def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = {
def resultSetToRows(
resultSet: ResultSet,
schema: StructType): Iterator[Row] = {
resultSetToRows(resultSet, schema, NoopDialect)
}

def resultSetToRows(
resultSet: ResultSet,
schema: StructType,
dialect: JdbcDialect): Iterator[Row] = {
val inputMetrics =
Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics)
val fromRow = RowEncoder(schema).resolveAndBind().createDeserializer()
val internalRows = resultSetToSparkInternalRows(resultSet, schema, inputMetrics)
val internalRows = resultSetToSparkInternalRows(resultSet, dialect, schema, inputMetrics)
internalRows.map(fromRow)
}

private[spark] def resultSetToSparkInternalRows(
resultSet: ResultSet,
dialect: JdbcDialect,
schema: StructType,
inputMetrics: InputMetrics): Iterator[InternalRow] = {
new NextIterator[InternalRow] {
private[this] val rs = resultSet
private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema)
private[this] val getters: Array[JDBCValueGetter] = makeGetters(dialect, schema)
private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))

override protected def close(): Unit = {
@@ -368,12 +378,17 @@ object JdbcUtils extends Logging with SQLConfHelper {
* Creates `JDBCValueGetter`s according to [[StructType]], which can set
* each value from `ResultSet` to each field of [[InternalRow]] correctly.
*/
private def makeGetters(schema: StructType): Array[JDBCValueGetter] = {
private def makeGetters(
dialect: JdbcDialect,
schema: StructType): Array[JDBCValueGetter] = {
val replaced = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema)
replaced.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
replaced.fields.map(sf => makeGetter(sf.dataType, dialect, sf.metadata))
}

private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
private def makeGetter(
dt: DataType,
dialect: JdbcDialect,
metadata: Metadata): JDBCValueGetter = dt match {
case BooleanType =>
(rs: ResultSet, row: InternalRow, pos: Int) =>
row.setBoolean(pos, rs.getBoolean(pos + 1))
@@ -478,7 +493,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
(rs: ResultSet, row: InternalRow, pos: Int) =>
val t = rs.getTimestamp(pos + 1)
if (t != null) {
row.setLong(pos, DateTimeUtils.fromJavaTimestampNoRebase(t))
row.setLong(pos,
DateTimeUtils.localDateTimeToMicros(dialect.convertJavaTimestampToTimestampNTZ(t)))
} else {
row.update(pos, null)
}
@@ -596,8 +612,8 @@ object JdbcUtils extends Logging with SQLConfHelper {

case TimestampNTZType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
val micros = localDateTimeToMicros(row.getAs[java.time.LocalDateTime](pos))
stmt.setTimestamp(pos + 1, toJavaTimestampNoRebase(micros))
stmt.setTimestamp(pos + 1,
dialect.convertTimestampNTZToJavaTimestamp(row.getAs[java.time.LocalDateTime](pos)))

case DateType =>
if (conf.datetimeJava8ApiEnabled) {
@@ -18,7 +18,7 @@
package org.apache.spark.sql.jdbc

import java.sql.{Connection, Date, Driver, Statement, Timestamp}
import java.time.{Instant, LocalDate}
import java.time.{Instant, LocalDate, LocalDateTime}
import java.util

import scala.collection.mutable.ArrayBuilder
@@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateTimeToMicros, toJavaTimestampNoRebase}
import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
import org.apache.spark.sql.connector.catalog.TableChange._
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
@@ -104,6 +105,31 @@ abstract class JdbcDialect extends Serializable with Logging {
*/
def getJDBCType(dt: DataType): Option[JdbcType] = None

/**
* Convert java.sql.Timestamp to a LocalDateTime representing the same wall-clock time as the
* value stored in a remote database.
* JDBC dialects should override this function to provide implementations that suit their
* JDBC drivers.
* @param t Timestamp returned from JDBC driver getTimestamp method.
* @return A LocalDateTime representing the same wall clock time as the timestamp in database.
*/
@Since("3.5.0")
def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime = {
DateTimeUtils.microsToLocalDateTime(DateTimeUtils.fromJavaTimestampNoRebase(t))
}

/**
* Converts a LocalDateTime representing a TimestampNTZ type to an
* instance of `java.sql.Timestamp`.
* @param ldt representing a TimestampNTZType.
* @return A Java Timestamp representing this LocalDateTime.
*/
@Since("3.5.0")
def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp = {
val micros = localDateTimeToMicros(ldt)
toJavaTimestampNoRebase(micros)
}

/**
* Returns a factory for creating connections to the given JDBC URL.
* In general, creating a connection has nothing to do with JDBC partition id.
@@ -682,6 +708,6 @@ object JdbcDialects {
/**
* NOOP dialect object, always returning the neutral element.
*/
private object NoopDialect extends JdbcDialect {
object NoopDialect extends JdbcDialect {
override def canHandle(url : String): Boolean = true
}
@@ -17,7 +17,8 @@

package org.apache.spark.sql.jdbc

import java.sql.{Connection, SQLException, Types}
import java.sql.{Connection, SQLException, Timestamp, Types}
import java.time.LocalDateTime
import java.util
import java.util.Locale

@@ -102,6 +103,14 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper {
case _ => None
}

override def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime = {
t.toLocalDateTime
}

override def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp = {
Timestamp.valueOf(ldt)
}

override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
case StringType => Some(JdbcType("TEXT", Types.VARCHAR))
case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))