explicit time formatters

bowenliang123 committed Dec 4, 2023
1 parent 3619589 commit f35ba84
Showing 6 changed files with 61 additions and 31 deletions.
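This commit threads an explicit `TimeFormatters` instance through the `RowSet` conversion path, so the formatters are built once per fetched result set instead of once per value inside `toHiveString`. A minimal sketch of the hoisting pattern, assuming Spark 3.1+ where `HiveResult.getTimeFormatters` and the three-argument `HiveResult.toHiveString` exist (the sample values are illustrative):

import java.sql.Timestamp

import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.types.TimestampType

val values = (0 to 2).map(i => Timestamp.valueOf(s"2018-11-17 13:33:33.$i"))
// Hoisted: one TimeFormatters instance serves every value in the result set.
val timeFormatters = HiveResult.getTimeFormatters
values.foreach { v =>
  // Before this commit, each call rebuilt the formatters internally.
  println(HiveResult.toHiveString(v -> TimestampType, false, timeFormatters))
}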
SparkOperation.scala
@@ -24,7 +24,7 @@ import org.apache.hive.service.rpc.thrift.{TFetchResultsResp, TGetResultSetMetad
 import org.apache.spark.kyuubi.{SparkProgressMonitor, SQLOperationListener}
 import org.apache.spark.kyuubi.SparkUtilsHelper.redact
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
-import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.{HiveResult, SQLExecution}
 import org.apache.spark.sql.types.StructType
 
 import org.apache.kyuubi.{KyuubiSQLException, Utils}
@@ -252,7 +252,8 @@ abstract class SparkOperation(session: Session)
       RowSet.toTRowSet(
         taken.toSeq.asInstanceOf[Seq[Row]],
         resultSchema,
-        getProtocolVersion)
+        getProtocolVersion,
+        HiveResult.getTimeFormatters)
     }
     resultRowSet.setStartRowOffset(iter.getPosition)
   }
RowSet.scala
@@ -24,15 +24,18 @@ import scala.collection.JavaConverters._
 import org.apache.hive.service.rpc.thrift._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.execution.HiveResult
+import org.apache.spark.sql.execution.HiveResult.TimeFormatters
 import org.apache.spark.sql.types._
 
 import org.apache.kyuubi.util.RowSetUtils._
 
 object RowSet {
 
-  def toHiveString(valueAndType: (Any, DataType), nested: Boolean = false): String = {
+  def toHiveString(
+      valueAndType: (Any, DataType),
+      nested: Boolean = false,
+      timeFormatters: TimeFormatters): String = {
     // compatible w/ Spark 3.1 and above
-    val timeFormatters = HiveResult.getTimeFormatters
     HiveResult.toHiveString(valueAndType, nested, timeFormatters)
   }

@@ -60,15 +63,16 @@ object RowSet {
   def toTRowSet(
       rows: Seq[Row],
       schema: StructType,
-      protocolVersion: TProtocolVersion): TRowSet = {
+      protocolVersion: TProtocolVersion,
+      timeFormatters: TimeFormatters): TRowSet = {
     if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
-      toRowBasedSet(rows, schema)
+      toRowBasedSet(rows, schema, timeFormatters)
     } else {
-      toColumnBasedSet(rows, schema)
+      toColumnBasedSet(rows, schema, timeFormatters)
     }
   }
 
-  def toRowBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
+  def toRowBasedSet(rows: Seq[Row], schema: StructType, timeFormatters: TimeFormatters): TRowSet = {
     val rowSize = rows.length
     val tRows = new java.util.ArrayList[TRow](rowSize)
     var i = 0
@@ -78,7 +82,7 @@
       var j = 0
       val columnSize = row.length
       while (j < columnSize) {
-        val columnValue = toTColumnValue(j, row, schema)
+        val columnValue = toTColumnValue(j, row, schema, timeFormatters)
         tRow.addToColVals(columnValue)
         j += 1
       }
@@ -88,21 +92,28 @@
     new TRowSet(0, tRows)
   }
 
-  def toColumnBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
+  def toColumnBasedSet(
+      rows: Seq[Row],
+      schema: StructType,
+      timeFormatters: TimeFormatters): TRowSet = {
     val rowSize = rows.length
     val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](rowSize))
     var i = 0
     val columnSize = schema.length
     while (i < columnSize) {
       val field = schema(i)
-      val tColumn = toTColumn(rows, i, field.dataType)
+      val tColumn = toTColumn(rows, i, field.dataType, timeFormatters)
       tRowSet.addToColumns(tColumn)
       i += 1
     }
     tRowSet
   }
 
-  private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType): TColumn = {
+  private def toTColumn(
+      rows: Seq[Row],
+      ordinal: Int,
+      typ: DataType,
+      timeFormatters: TimeFormatters): TColumn = {
     val nulls = new java.util.BitSet()
     typ match {
       case BooleanType =>
@@ -152,7 +163,7 @@
        while (i < rowSize) {
          val row = rows(i)
          nulls.set(i, row.isNullAt(ordinal))
-          values.add(toHiveString(row.get(ordinal) -> typ))
+          values.add(toHiveString(row.get(ordinal) -> typ, timeFormatters = timeFormatters))
          i += 1
        }
        TColumn.stringVal(new TStringColumn(values, nulls))
@@ -184,7 +195,8 @@
   private def toTColumnValue(
       ordinal: Int,
       row: Row,
-      types: StructType): TColumnValue = {
+      types: StructType,
+      timeFormatters: TimeFormatters): TColumnValue = {
     types(ordinal).dataType match {
       case BooleanType =>
         val boolValue = new TBoolValue
@@ -232,7 +244,9 @@
       case _ =>
         val tStrValue = new TStringValue
         if (!row.isNullAt(ordinal)) {
-          tStrValue.setValue(toHiveString(row.get(ordinal) -> types(ordinal).dataType))
+          tStrValue.setValue(toHiveString(
+            row.get(ordinal) -> types(ordinal).dataType,
+            timeFormatters = timeFormatters))
         }
         TColumnValue.stringVal(tStrValue)
     }
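With the widened signatures above, each entry point builds the formatters once and passes them down, as SparkOperation does in the first file. A usage sketch under the same Spark 3.1+ assumption (the schema and protocol version here are illustrative, and the `RowSet` import is omitted):

import org.apache.hive.service.rpc.thrift.TProtocolVersion
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema = StructType(Seq(StructField("id", IntegerType)))
val rows = Seq(Row(1), Row(2))
// Built once here, reused for every column and row during conversion.
val tRowSet = RowSet.toTRowSet(
  rows,
  schema,
  TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V10,
  HiveResult.getTimeFormatters)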
SparkDatasetHelper.scala
@@ -25,7 +25,7 @@ import org.apache.spark.network.util.{ByteUnit, JavaUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{CollectLimitExec, HiveResult, LocalTableScanExec, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -106,16 +106,17 @@ object SparkDatasetHelper extends Logging {
 
   // an udf to call `RowSet.toHiveString` on complex types(struct/array/map) and timestamp type.
   val toHiveStringUDF = udf[String, Row, String]((row, schemaDDL) => {
+    val timeFormatters = HiveResult.getTimeFormatters
     val dt = DataType.fromDDL(schemaDDL)
     dt match {
       case StructType(Array(StructField(_, st: StructType, _, _))) =>
-        RowSet.toHiveString((row, st), nested = true)
+        RowSet.toHiveString((row, st), nested = true, timeFormatters = timeFormatters)
       case StructType(Array(StructField(_, at: ArrayType, _, _))) =>
-        RowSet.toHiveString((row.toSeq.head, at), nested = true)
+        RowSet.toHiveString((row.toSeq.head, at), nested = true, timeFormatters = timeFormatters)
       case StructType(Array(StructField(_, mt: MapType, _, _))) =>
-        RowSet.toHiveString((row.toSeq.head, mt), nested = true)
+        RowSet.toHiveString((row.toSeq.head, mt), nested = true, timeFormatters = timeFormatters)
       case StructType(Array(StructField(_, tt: TimestampType, _, _))) =>
-        RowSet.toHiveString((row.toSeq.head, tt), nested = true)
+        RowSet.toHiveString((row.toSeq.head, tt), nested = true, timeFormatters = timeFormatters)
       case _ =>
         throw new UnsupportedOperationException
     }
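Note that the `val timeFormatters` added above sits inside the UDF lambda, so this Arrow path still appears to build the formatters on every invocation, i.e. once per row, presumably because the lambda runs on executors with no per-result-set instance to share. A small sketch of the single-field DDL dispatch the UDF performs (the `struct<ts:timestamp>` literal is illustrative):

import org.apache.spark.sql.types._

DataType.fromDDL("struct<ts:timestamp>") match {
  case StructType(Array(StructField(_, _: TimestampType, _, _))) =>
    println("timestamp payload: format row.toSeq.head with the formatters")
  case other =>
    println(s"unhandled schema: $other")
}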
RowSetHelper.scala
@@ -43,7 +43,7 @@ trait RowSetHelper {
     val dateVal = Date.valueOf(s"2018-11-$day")
     val timestampVal = Timestamp.valueOf(s"2018-11-17 13:33:33.$value")
     val binaryVal = Array.fill[Byte](value)(value.toByte)
-    val arrVal = Array.fill(value)(doubleVal).toSeq
+    val arrVal = Array.fill(10)(doubleVal).toSeq
     val mapVal = Map(value -> doubleVal)
     val interval = new CalendarInterval(value, value, value)
     val localDate = LocalDate.of(2018, 11, 17)
RowSetSuite.scala
@@ -25,6 +25,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.hive.service.rpc.thrift.TProtocolVersion
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.HiveResult
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 
@@ -34,7 +35,8 @@ class RowSetSuite extends KyuubiFunSuite with RowSetHelper {
   private val rows: Seq[Row] = (0 to 10).map(genRow) ++ Seq(Row.fromSeq(Seq.fill(17)(null)))
 
   test("column based set") {
-    val tRowSet = RowSet.toColumnBasedSet(rows, schema)
+    val timeFormatters = HiveResult.getTimeFormatters
+    val tRowSet = RowSet.toColumnBasedSet(rows, schema, timeFormatters)
     assert(tRowSet.getColumns.size() === schema.size)
     assert(tRowSet.getRowsSize === 0)
 
@@ -101,14 +103,18 @@
     dateCol.getValues.asScala.zipWithIndex.foreach {
       case (b, 11) => assert(b === "NULL")
       case (b, i) =>
-        assert(b === RowSet.toHiveString(Date.valueOf(s"2018-11-${i + 1}") -> DateType))
+        assert(b === RowSet.toHiveString(
+          Date.valueOf(s"2018-11-${i + 1}") -> DateType,
+          timeFormatters = timeFormatters))
     }
 
     val tsCol = cols.next().getStringVal
     tsCol.getValues.asScala.zipWithIndex.foreach {
       case (b, 11) => assert(b === "NULL")
       case (b, i) => assert(b ===
-        RowSet.toHiveString(Timestamp.valueOf(s"2018-11-17 13:33:33.$i") -> TimestampType))
+        RowSet.toHiveString(
+          Timestamp.valueOf(s"2018-11-17 13:33:33.$i") -> TimestampType,
+          timeFormatters = timeFormatters))
     }
 
     val binCol = cols.next().getBinaryVal
@@ -121,14 +127,16 @@
     arrCol.getValues.asScala.zipWithIndex.foreach {
       case (b, 11) => assert(b === "NULL")
       case (b, i) => assert(b === RowSet.toHiveString(
-        Array.fill(i)(java.lang.Double.valueOf(s"$i.$i")).toSeq -> ArrayType(DoubleType)))
+        Array.fill(i)(java.lang.Double.valueOf(s"$i.$i")).toSeq -> ArrayType(DoubleType),
+        timeFormatters = timeFormatters))
     }
 
     val mapCol = cols.next().getStringVal
     mapCol.getValues.asScala.zipWithIndex.foreach {
       case (b, 11) => assert(b === "NULL")
       case (b, i) => assert(b === RowSet.toHiveString(
-        Map(i -> java.lang.Double.valueOf(s"$i.$i")) -> MapType(IntegerType, DoubleType)))
+        Map(i -> java.lang.Double.valueOf(s"$i.$i")) -> MapType(IntegerType, DoubleType),
+        timeFormatters = timeFormatters))
     }
 
     val intervalCol = cols.next().getStringVal
@@ -139,7 +147,8 @@
   }
 
   test("row based set") {
-    val tRowSet = RowSet.toRowBasedSet(rows, schema)
+    val timeFormatters = HiveResult.getTimeFormatters
+    val tRowSet = RowSet.toRowBasedSet(rows, schema, timeFormatters)
     assert(tRowSet.getColumnCount === 0)
     assert(tRowSet.getRowsSize === rows.size)
     val iter = tRowSet.getRowsIterator
@@ -177,15 +186,18 @@
     val r8 = iter.next().getColVals
     assert(r8.get(12).getStringVal.getValue === Array.fill(7)(7.7d).mkString("[", ",", "]"))
     assert(r8.get(13).getStringVal.getValue ===
-      RowSet.toHiveString(Map(7 -> 7.7d) -> MapType(IntegerType, DoubleType)))
+      RowSet.toHiveString(
+        Map(7 -> 7.7d) -> MapType(IntegerType, DoubleType),
+        timeFormatters = timeFormatters))
 
     val r9 = iter.next().getColVals
     assert(r9.get(14).getStringVal.getValue === new CalendarInterval(8, 8, 8).toString)
   }
 
   test("to row set") {
+    val timeFormatters = HiveResult.getTimeFormatters
     TProtocolVersion.values().foreach { proto =>
-      val set = RowSet.toTRowSet(rows, schema, proto)
+      val set = RowSet.toTRowSet(rows, schema, proto, timeFormatters)
       if (proto.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
         assert(!set.isSetColumns, proto.toString)
         assert(set.isSetRows, proto.toString)
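The `case (b, 11)` branches above depend on null values rendering as the literal string "NULL" (row 11 is the all-null row). A minimal check of that behavior, again assuming the Spark 3.1+ `HiveResult`:

import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.types.DateType

val formatters = HiveResult.getTimeFormatters
// Top-level nulls render as "NULL"; nested values would render as "null".
assert(HiveResult.toHiveString((null, DateType), false, formatters) == "NULL")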
TRowSetBenchmark.scala
@@ -23,6 +23,7 @@ import org.apache.commons.lang3.time.StopWatch
 import org.apache.hive.service.rpc.thrift.TProtocolVersion
 import org.apache.hive.service.rpc.thrift.TProtocolVersion._
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.HiveResult
 import org.apache.spark.sql.types._
 
 import org.apache.kyuubi.KyuubiFunSuite
@@ -37,7 +38,7 @@
  * }}}
  */
 class TRowSetBenchmark extends KyuubiFunSuite with RowSetHelper {
-  private val runBenchmark = sys.env.contains("RUN_BENCHMARK")
+  private val runBenchmark = sys.env.contains("RUN_BENCHMARK") || true
 
   private val rowCount = 3000
   private lazy val allRows = (0 until rowCount).map(genRow)
@@ -94,7 +95,8 @@
       schema: StructType,
       protocolVersion: TProtocolVersion): BigDecimal = {
     val sw = StopWatch.createStarted()
-    RowSet.toTRowSet(rows, schema, protocolVersion)
+    val timeFormatters = HiveResult.getTimeFormatters
+    RowSet.toTRowSet(rows, schema, protocolVersion, timeFormatters)
     sw.stop()
     val msTimeCost: BigDecimal = (BigDecimal(sw.getNanoTime) / BigDecimal(1000000))
       .setScale(3, RoundingMode.HALF_UP)
