diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala
index 4f935ce49f0..e8670b98246 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala
@@ -24,15 +24,17 @@ import scala.collection.JavaConverters._
 import org.apache.hive.service.rpc.thrift._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.execution.HiveResult
+import org.apache.spark.sql.execution.HiveResult.TimeFormatters
 import org.apache.spark.sql.types._
 
 import org.apache.kyuubi.util.RowSetUtils._
 
 object RowSet {
 
-  def toHiveString(valueAndType: (Any, DataType), nested: Boolean = false): String = {
-    // compatible w/ Spark 3.1 and above
-    val timeFormatters = HiveResult.getTimeFormatters
+  def toHiveString(
+      valueAndType: (Any, DataType),
+      nested: Boolean = false,
+      timeFormatters: TimeFormatters): String = {
     HiveResult.toHiveString(valueAndType, nested, timeFormatters)
   }
 
@@ -71,6 +73,7 @@
   def toRowBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
     val rowSize = rows.length
     val tRows = new java.util.ArrayList[TRow](rowSize)
+    val timeFormatters = HiveResult.getTimeFormatters
     var i = 0
     while (i < rowSize) {
       val row = rows(i)
@@ -78,7 +81,7 @@
       var j = 0
       val columnSize = row.length
       while (j < columnSize) {
-        val columnValue = toTColumnValue(j, row, schema)
+        val columnValue = toTColumnValue(j, row, schema, timeFormatters)
         tRow.addToColVals(columnValue)
         j += 1
       }
@@ -91,18 +94,23 @@
   def toColumnBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
     val rowSize = rows.length
     val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](rowSize))
+    val timeFormatters = HiveResult.getTimeFormatters
     var i = 0
     val columnSize = schema.length
     while (i < columnSize) {
       val field = schema(i)
-      val tColumn = toTColumn(rows, i, field.dataType)
+      val tColumn = toTColumn(rows, i, field.dataType, timeFormatters)
       tRowSet.addToColumns(tColumn)
       i += 1
     }
     tRowSet
   }
 
-  private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType): TColumn = {
+  private def toTColumn(
+      rows: Seq[Row],
+      ordinal: Int,
+      typ: DataType,
+      timeFormatters: TimeFormatters): TColumn = {
     val nulls = new java.util.BitSet()
     typ match {
       case BooleanType =>
@@ -152,7 +160,7 @@
         while (i < rowSize) {
           val row = rows(i)
           nulls.set(i, row.isNullAt(ordinal))
-          values.add(toHiveString(row.get(ordinal) -> typ))
+          values.add(toHiveString(row.get(ordinal) -> typ, timeFormatters = timeFormatters))
           i += 1
         }
         TColumn.stringVal(new TStringColumn(values, nulls))
@@ -184,7 +192,8 @@
   private def toTColumnValue(
       ordinal: Int,
       row: Row,
-      types: StructType): TColumnValue = {
+      types: StructType,
+      timeFormatters: TimeFormatters): TColumnValue = {
     types(ordinal).dataType match {
       case BooleanType =>
         val boolValue = new TBoolValue
@@ -232,7 +241,9 @@
       case _ =>
         val tStrValue = new TStringValue
         if (!row.isNullAt(ordinal)) {
-          tStrValue.setValue(toHiveString(row.get(ordinal) -> types(ordinal).dataType))
+          tStrValue.setValue(toHiveString(
+            row.get(ordinal) -> types(ordinal).dataType,
+            timeFormatters = timeFormatters))
         }
         TColumnValue.stringVal(tStrValue)
     }
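Note on the RowSet.scala hunks above: `toHiveString` now accepts the `TimeFormatters` as a parameter, so `toRowBasedSet` and `toColumnBasedSet` build the formatters once per TRowSet instead of once per converted value. A minimal sketch of the resulting calling pattern, assuming a Spark 3.1+ classpath; the `formatBatch` helper name and shape are illustrative, not part of the patch:

    import org.apache.spark.sql.execution.HiveResult
    import org.apache.spark.sql.types.DataType

    // Build the formatters once per batch of values (the TODO below notes that
    // recreating them repeatedly is a potential performance issue), then reuse
    // them for every conversion in that batch.
    def formatBatch(values: Seq[Any], dt: DataType): Seq[String] = {
      val timeFormatters = HiveResult.getTimeFormatters // once per batch
      values.map(v => HiveResult.toHiveString((v, dt), false, timeFormatters))
    }
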
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
index c0f9d61c210..e8431226876 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.HiveResult
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -105,17 +106,31 @@
     val quotedCol = (name: String) => col(quoteIfNeeded(name))
 
     // an udf to call `RowSet.toHiveString` on complex types(struct/array/map) and timestamp type.
+    // TODO: reuse the timeFormatters on greater scale if possible,
+    // recreating timeFormatters may cause performance issue, see [KYUUBI#5811]
     val toHiveStringUDF = udf[String, Row, String]((row, schemaDDL) => {
       val dt = DataType.fromDDL(schemaDDL)
       dt match {
         case StructType(Array(StructField(_, st: StructType, _, _))) =>
-          RowSet.toHiveString((row, st), nested = true)
+          RowSet.toHiveString(
+            (row, st),
+            nested = true,
+            timeFormatters = HiveResult.getTimeFormatters)
         case StructType(Array(StructField(_, at: ArrayType, _, _))) =>
-          RowSet.toHiveString((row.toSeq.head, at), nested = true)
+          RowSet.toHiveString(
+            (row.toSeq.head, at),
+            nested = true,
+            timeFormatters = HiveResult.getTimeFormatters)
         case StructType(Array(StructField(_, mt: MapType, _, _))) =>
-          RowSet.toHiveString((row.toSeq.head, mt), nested = true)
+          RowSet.toHiveString(
+            (row.toSeq.head, mt),
+            nested = true,
+            timeFormatters = HiveResult.getTimeFormatters)
         case StructType(Array(StructField(_, tt: TimestampType, _, _))) =>
-          RowSet.toHiveString((row.toSeq.head, tt), nested = true)
+          RowSet.toHiveString(
+            (row.toSeq.head, tt),
+            nested = true,
+            timeFormatters = HiveResult.getTimeFormatters)
         case _ =>
           throw new UnsupportedOperationException
       }
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala
index 5d2ba4a0d11..5c140291092 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala
@@ -26,6 +26,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.hive.service.rpc.thrift.TProtocolVersion
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.HiveResult
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 
@@ -165,14 +166,18 @@
     dateCol.getValues.asScala.zipWithIndex.foreach {
       case (b, 11) => assert(b === "NULL")
       case (b, i) =>
-        assert(b === RowSet.toHiveString(Date.valueOf(s"2018-11-${i + 1}") -> DateType))
+        assert(b === RowSet.toHiveString(
+          Date.valueOf(s"2018-11-${i + 1}") -> DateType,
+          timeFormatters = HiveResult.getTimeFormatters))
     }
 
     val tsCol = cols.next().getStringVal
     tsCol.getValues.asScala.zipWithIndex.foreach {
       case (b, 11) => assert(b === "NULL")
       case (b, i) => assert(b ===
-          RowSet.toHiveString(Timestamp.valueOf(s"2018-11-17 13:33:33.$i") -> TimestampType))
+          RowSet.toHiveString(
+            Timestamp.valueOf(s"2018-11-17 13:33:33.$i") -> TimestampType,
+            timeFormatters = HiveResult.getTimeFormatters))
     }
 
     val binCol = cols.next().getBinaryVal
@@ -185,14 +190,16 @@
     arrCol.getValues.asScala.zipWithIndex.foreach {
       case (b, 11) => assert(b === "NULL")
       case (b, i) => assert(b === RowSet.toHiveString(
-        Array.fill(i)(java.lang.Double.valueOf(s"$i.$i")).toSeq -> ArrayType(DoubleType)))
+        Array.fill(i)(java.lang.Double.valueOf(s"$i.$i")).toSeq -> ArrayType(DoubleType),
+        timeFormatters = HiveResult.getTimeFormatters))
     }
 
     val mapCol = cols.next().getStringVal
     mapCol.getValues.asScala.zipWithIndex.foreach {
       case (b, 11) => assert(b === "NULL")
       case (b, i) => assert(b === RowSet.toHiveString(
-        Map(i -> java.lang.Double.valueOf(s"$i.$i")) -> MapType(IntegerType, DoubleType)))
+        Map(i -> java.lang.Double.valueOf(s"$i.$i")) -> MapType(IntegerType, DoubleType),
+        timeFormatters = HiveResult.getTimeFormatters))
     }
 
     val intervalCol = cols.next().getStringVal
@@ -241,7 +248,9 @@
     val r8 = iter.next().getColVals
     assert(r8.get(12).getStringVal.getValue === Array.fill(7)(7.7d).mkString("[", ",", "]"))
     assert(r8.get(13).getStringVal.getValue ===
-      RowSet.toHiveString(Map(7 -> 7.7d) -> MapType(IntegerType, DoubleType)))
+      RowSet.toHiveString(
+        Map(7 -> 7.7d) -> MapType(IntegerType, DoubleType),
+        timeFormatters = HiveResult.getTimeFormatters))
 
     val r9 = iter.next().getColVals
     assert(r9.get(14).getStringVal.getValue === new CalendarInterval(8, 8, 8).toString)
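Note on the TODO added in SparkDatasetHelper.scala: the Arrow path still calls `HiveResult.getTimeFormatters` once per row inside `toHiveStringUDF`. One possible direction for coarser reuse, sketched purely as an illustration and under the assumption that the formatters can be created executor-side; `toHiveStringPerPartition` is a hypothetical helper, not part of this patch:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.execution.HiveResult
    import org.apache.spark.sql.types.DataType

    // Hypothetical per-partition variant: create the formatters once per
    // partition on the executor rather than once per row inside a UDF.
    def toHiveStringPerPartition(df: DataFrame, dt: DataType): RDD[String] =
      df.rdd.mapPartitions { rows =>
        val timeFormatters = HiveResult.getTimeFormatters // once per partition
        rows.map(row => HiveResult.toHiveString((row.get(0), dt), true, timeFormatters))
      }

Creating the formatters once per partition would mirror what `toRowBasedSet` and `toColumnBasedSet` now do per TRowSet on the driver; whether reuse at that or an even greater scale is safe is exactly what the TODO and [KYUUBI#5811] leave open.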