[SPARK-43040][SQL] Improve TimestampNTZ type support in JDBC data source #40678

Closed
PostgresIntegrationSuite.scala

@@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc
import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Connection, Date, Timestamp}
import java.text.SimpleDateFormat
+import java.time.LocalDateTime
import java.util.Properties

import org.apache.spark.sql.Column
@@ -140,6 +141,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
"c0 money)").executeUpdate()
conn.prepareStatement("INSERT INTO money_types VALUES " +
"('$1,000.00')").executeUpdate()

conn.prepareStatement(s"CREATE TABLE timestamp_ntz(v timestamp)").executeUpdate()
tianhanhu marked this conversation as resolved.
Show resolved Hide resolved
conn.prepareStatement(s"""INSERT INTO timestamp_ntz VALUES
|('2013-04-05 12:01:02'),
|('2013-04-05 18:01:02.123'),
|('2013-04-05 18:01:02.123456')""".stripMargin).executeUpdate()
}

test("Type mapping for various types") {
@@ -381,4 +388,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(row(0).length === 1)
assert(row(0).getString(0) === "$1,000.00")
}

test("SPARK-43040: timestamp_ntz read test") {
val prop = new Properties
prop.setProperty("preferTimestampNTZ", "true")
val df = sqlContext.read.jdbc(jdbcUrl, "timestamp_ntz", prop)
val row = df.collect()
tianhanhu marked this conversation as resolved.
Show resolved Hide resolved
assert(row.length === 3)
assert(row(0).length === 1)
assert(row(0) === Row(LocalDateTime.of(2013, 4, 5, 12, 1, 2)))
tianhanhu marked this conversation as resolved.
Show resolved Hide resolved
assert(row(1) === Row(LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123000000)))
assert(row(2) === Row(LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123456000)))
}

test("SPARK-43040: timestamp_ntz roundtrip test") {
val prop = new Properties
prop.setProperty("preferTimestampNTZ", "true")

val sparkQuery = """
|select
| timestamp_ntz'2020-12-10 11:22:33' as col0
""".stripMargin

val df_expected = sqlContext.sql(sparkQuery)
df_expected.write.jdbc(jdbcUrl, "timestamp_ntz_roundtrip", prop)

val df_actual = sqlContext.read.jdbc(jdbcUrl, "timestamp_ntz_roundtrip", prop)
assert(df_actual.collect()(0) == df_expected.collect()(0))
}
}
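For orientation, here is a minimal end-user sketch of the behavior these two tests pin down. The JDBC URL is a placeholder, and note that the Postgres column type timestamp used in dataPreparation is shorthand for timestamp without time zone:

    import java.util.Properties
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    val props = new Properties()
    // Opt in to reading TIMESTAMP WITHOUT TIME ZONE as Spark's TimestampNTZType.
    props.setProperty("preferTimestampNTZ", "true")
    // Placeholder URL; any reachable PostgreSQL instance with the table above works.
    val df = spark.read.jdbc("jdbc:postgresql://localhost:5432/test", "timestamp_ntz", props)
    df.printSchema()  // v is timestamp_ntz here; without the option it is the session-time-zone timestamp
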
JDBCRDD.scala

@@ -273,7 +273,8 @@ private[jdbc] class JDBCRDD(
stmt.setFetchSize(options.fetchSize)
stmt.setQueryTimeout(options.queryTimeout)
rs = stmt.executeQuery()
-val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
+val rowsIterator =
+JdbcUtils.resultSetToSparkInternalRows(rs, dialect, schema, inputMetrics)

CompletionIterator[InternalRow, Iterator[InternalRow]](
new InterruptibleIterator(context, rowsIterator), close())
JdbcUtils.scala

@@ -38,12 +38,12 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, GenericArrayData}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros, localDateToDays, toJavaDate, toJavaTimestamp, toJavaTimestampNoRebase}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex}
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, NoopDialect}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.unsafe.types.UTF8String
@@ -316,21 +316,31 @@ object JdbcUtils extends Logging with SQLConfHelper {
/**
* Convert a [[ResultSet]] into an iterator of Catalyst Rows.
*/
-def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = {
+def resultSetToRows(
+resultSet: ResultSet,
+schema: StructType): Iterator[Row] = {
+resultSetToRows(resultSet, schema, NoopDialect)
+}
+
+def resultSetToRows(
+resultSet: ResultSet,
+schema: StructType,
+dialect: JdbcDialect): Iterator[Row] = {
val inputMetrics =
Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics)
val fromRow = RowEncoder(schema).resolveAndBind().createDeserializer()
-val internalRows = resultSetToSparkInternalRows(resultSet, schema, inputMetrics)
+val internalRows = resultSetToSparkInternalRows(resultSet, dialect, schema, inputMetrics)
internalRows.map(fromRow)
}

private[spark] def resultSetToSparkInternalRows(
resultSet: ResultSet,
+dialect: JdbcDialect,
schema: StructType,
inputMetrics: InputMetrics): Iterator[InternalRow] = {
new NextIterator[InternalRow] {
private[this] val rs = resultSet
-private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema)
+private[this] val getters: Array[JDBCValueGetter] = makeGetters(dialect, schema)
private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))

override protected def close(): Unit = {
@@ -368,12 +378,17 @@ object JdbcUtils extends Logging with SQLConfHelper {
* Creates `JDBCValueGetter`s according to [[StructType]], which can set
* each value from `ResultSet` to each field of [[InternalRow]] correctly.
*/
-private def makeGetters(schema: StructType): Array[JDBCValueGetter] = {
+private def makeGetters(
+dialect: JdbcDialect,
+schema: StructType): Array[JDBCValueGetter] = {
val replaced = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema)
-replaced.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
+replaced.fields.map(sf => makeGetter(sf.dataType, dialect, sf.metadata))
}

-private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
+private def makeGetter(
+dt: DataType,
+dialect: JdbcDialect,
+metadata: Metadata): JDBCValueGetter = dt match {
case BooleanType =>
(rs: ResultSet, row: InternalRow, pos: Int) =>
row.setBoolean(pos, rs.getBoolean(pos + 1))
@@ -478,7 +493,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
(rs: ResultSet, row: InternalRow, pos: Int) =>
val t = rs.getTimestamp(pos + 1)
if (t != null) {
-row.setLong(pos, DateTimeUtils.fromJavaTimestampNoRebase(t))
+row.setLong(pos,
+DateTimeUtils.localDateTimeToMicros(dialect.convertJavaTimestampToTimestampNTZ(t)))
} else {
row.update(pos, null)
}
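The key change in this getter: instead of reinterpreting the driver's Timestamp as epoch microseconds directly (the removed fromJavaTimestampNoRebase call), it asks the dialect for a wall-clock LocalDateTime first. A standalone sketch of why that preserves the value, using only java.sql and java.time:

    import java.sql.Timestamp
    import java.time.LocalDateTime

    // Timestamp.valueOf parses in the JVM's default time zone and toLocalDateTime
    // reads the fields back in that same zone, so the wall-clock value survives,
    // which is exactly the property a TimestampNTZ (no time zone) column needs.
    val t = Timestamp.valueOf("2013-04-05 18:01:02.123456")
    val ldt: LocalDateTime = t.toLocalDateTime
    assert(ldt == LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123456000))
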
@@ -596,8 +612,8 @@ object JdbcUtils extends Logging with SQLConfHelper {

case TimestampNTZType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
-val micros = localDateTimeToMicros(row.getAs[java.time.LocalDateTime](pos))
-stmt.setTimestamp(pos + 1, toJavaTimestampNoRebase(micros))
+stmt.setTimestamp(pos + 1,
+dialect.convertTimestampNTZToJavaTimestamp(row.getAs[java.time.LocalDateTime](pos)))
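The setter is the mirror image: the TimestampNTZ value leaves Spark as a wall-clock LocalDateTime, and the dialect decides how to box it into the java.sql.Timestamp the driver expects. A small sketch of the round trip under the Postgres override shown further below:

    import java.sql.Timestamp
    import java.time.LocalDateTime

    val ldt = LocalDateTime.of(2020, 12, 10, 11, 22, 33)
    val t: Timestamp = Timestamp.valueOf(ldt)  // wall clock kept, no zone shift applied
    assert(t.toLocalDateTime == ldt)           // round-trips exactly
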

case DateType =>
if (conf.datetimeJava8ApiEnabled) {
JdbcDialects.scala

@@ -18,7 +18,7 @@
package org.apache.spark.sql.jdbc

import java.sql.{Connection, Date, Driver, Statement, Timestamp}
-import java.time.{Instant, LocalDate}
+import java.time.{Instant, LocalDate, LocalDateTime}
import java.util

import scala.collection.mutable.ArrayBuilder
@@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateTimeToMicros, toJavaTimestampNoRebase}
import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
import org.apache.spark.sql.connector.catalog.TableChange._
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
@@ -104,6 +105,31 @@ abstract class JdbcDialect extends Serializable with Logging {
*/
def getJDBCType(dt: DataType): Option[JdbcType] = None

+/**
+ * Converts a java.sql.Timestamp to a LocalDateTime representing the same wall-clock time as the
+ * value stored in a remote database.
+ * JDBC dialects should override this function to provide implementations that suit their
+ * JDBC drivers.
+ * @param t Timestamp returned from the JDBC driver's getTimestamp method.
+ * @return A LocalDateTime representing the same wall-clock time as the timestamp in the database.
+ */
+@Since("3.5.0")
+def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime = {
+DateTimeUtils.microsToLocalDateTime(DateTimeUtils.fromJavaTimestampNoRebase(t))
+}
+
+/**
+ * Converts a LocalDateTime representing a TimestampNTZ type to an
+ * instance of `java.sql.Timestamp`.
+ * @param ldt a LocalDateTime representing a TimestampNTZType value.
+ * @return A Java Timestamp representing this LocalDateTime.
+ */
+@Since("3.5.0")
+def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp = {
+val micros = localDateTimeToMicros(ldt)
+toJavaTimestampNoRebase(micros)
+}
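Any third-party dialect can override this pair. As a purely hypothetical illustration (the dialect name and URL scheme below are not part of this patch), a driver that reports timestamps as UTC instants would implement both directions symmetrically:

    import java.sql.Timestamp
    import java.time.{LocalDateTime, ZoneOffset}
    import org.apache.spark.sql.jdbc.JdbcDialect

    object ExampleUtcDialect extends JdbcDialect {
      override def canHandle(url: String): Boolean = url.startsWith("jdbc:exampledb")

      // The driver hands back UTC instants: unwrap them at UTC to get the wall clock.
      override def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime =
        t.toInstant.atOffset(ZoneOffset.UTC).toLocalDateTime

      // And re-wrap the wall clock as a UTC instant on the way out.
      override def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp =
        Timestamp.from(ldt.toInstant(ZoneOffset.UTC))
    }
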

/**
* Returns a factory for creating connections to the given JDBC URL.
* In general, creating a connection has nothing to do with JDBC partition id.
@@ -682,6 +708,6 @@
/**
* NOOP dialect object, always returning the neutral element.
*/
-private object NoopDialect extends JdbcDialect {
+object NoopDialect extends JdbcDialect {
override def canHandle(url : String): Boolean = true
}
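NoopDialect loses its private modifier so the dialect-free resultSetToRows overload kept in JdbcUtils (see above) can reference it from another package. A usage sketch, assuming JdbcUtils is reachable from your code (it lives in an internal Spark package) and a placeholder connection URL:

    import java.sql.DriverManager
    import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
    import org.apache.spark.sql.types.{StructField, StructType, TimestampNTZType}

    val conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/test")
    val rs = conn.createStatement().executeQuery("SELECT v FROM timestamp_ntz")
    val schema = StructType(Seq(StructField("v", TimestampNTZType)))
    // The two-argument overload delegates to the dialect-aware one with NoopDialect.
    val rows = JdbcUtils.resultSetToRows(rs, schema)
    rows.foreach(println)
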
PostgresDialect.scala

@@ -17,7 +17,8 @@

package org.apache.spark.sql.jdbc

-import java.sql.{Connection, SQLException, Types}
+import java.sql.{Connection, SQLException, Timestamp, Types}
+import java.time.LocalDateTime
import java.util
import java.util.Locale

@@ -102,6 +103,14 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper {
case _ => None
}

+override def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime = {
+t.toLocalDateTime
+}

Review comment (Contributor): After the change (see https://github.com/apache/spark/pull/40678/files#r1162437868), we can update this to DateTimeUtils.localDateTimeToMicros(t.toLocalDateTime) here.

+
+override def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp = {
+Timestamp.valueOf(ldt)
+}

override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
case StringType => Some(JdbcType("TEXT", Types.VARCHAR))
case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
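Finally, why the bare t.toLocalDateTime is right for Postgres: pgjdbc materializes a TIMESTAMP WITHOUT TIME ZONE value in the JVM's default time zone, and toLocalDateTime reads it back in that same zone, so the stored wall clock is recovered whatever the zone happens to be. A self-contained sketch of that symmetry, with Timestamp.valueOf standing in for the driver:

    import java.sql.Timestamp
    import java.time.LocalDateTime
    import java.util.TimeZone

    // Parse and read back the same wall clock under two different default zones.
    def roundTrip(zone: String): LocalDateTime = {
      val saved = TimeZone.getDefault
      TimeZone.setDefault(TimeZone.getTimeZone(zone))
      try Timestamp.valueOf("2013-04-05 18:01:02").toLocalDateTime
      finally TimeZone.setDefault(saved)
    }

    assert(roundTrip("UTC") == roundTrip("America/Los_Angeles"))  // zone-independent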