diff --git a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/operation/ChatOperation.scala b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/operation/ChatOperation.scala index 9cddc3e66ab..60f15ea6534 100644 --- a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/operation/ChatOperation.scala +++ b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/operation/ChatOperation.scala @@ -18,7 +18,8 @@ package org.apache.kyuubi.engine.chat.operation import org.apache.kyuubi.{KyuubiSQLException, Utils} import org.apache.kyuubi.config.KyuubiConf -import org.apache.kyuubi.engine.chat.schema.{RowSet, SchemaHelper} +import org.apache.kyuubi.engine.chat.schema.{ChatTRowSetGenerator, SchemaHelper} +import org.apache.kyuubi.engine.chat.schema.ChatTRowSetGenerator.COL_STRING_TYPE import org.apache.kyuubi.operation.{AbstractOperation, FetchIterator, OperationState} import org.apache.kyuubi.operation.FetchOrientation.{FETCH_FIRST, FETCH_NEXT, FETCH_PRIOR, FetchOrientation} import org.apache.kyuubi.session.Session @@ -45,8 +46,11 @@ abstract class ChatOperation(session: Session) extends AbstractOperation(session iter.fetchAbsolute(0) } - val taken = iter.take(rowSetSize) - val resultRowSet = RowSet.toTRowSet(taken.toSeq, 1, getProtocolVersion) + val taken = iter.take(rowSetSize).map(_.toSeq) + val resultRowSet = new ChatTRowSetGenerator().toTRowSet( + taken.toSeq, + Seq(COL_STRING_TYPE), + getProtocolVersion) resultRowSet.setStartRowOffset(iter.getPosition) val resp = new TFetchResultsResp(OK_STATUS) resp.setResults(resultRowSet) diff --git a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/schema/ChatTRowSetGenerator.scala b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/schema/ChatTRowSetGenerator.scala new file mode 100644 index 00000000000..990a1976480 --- /dev/null +++ b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/schema/ChatTRowSetGenerator.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.kyuubi.engine.chat.schema
+
+import org.apache.kyuubi.engine.chat.schema.ChatTRowSetGenerator._
+import org.apache.kyuubi.engine.schema.AbstractTRowSetGenerator
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TTypeId._
+
+class ChatTRowSetGenerator
+  extends AbstractTRowSetGenerator[Seq[String], Seq[String], String] {
+
+  override def getColumnSizeFromSchemaType(schema: Seq[String]): Int = schema.length
+
+  override def getColumnType(schema: Seq[String], ordinal: Int): String = COL_STRING_TYPE
+
+  override protected def isColumnNullAt(row: Seq[String], ordinal: Int): Boolean =
+    row(ordinal) == null
+
+  override def getColumnAs[T](row: Seq[String], ordinal: Int): T = row(ordinal).asInstanceOf[T]
+
+  override def toTColumn(rows: Seq[Seq[String]], ordinal: Int, typ: String): TColumn =
+    typ match {
+      case COL_STRING_TYPE => toTTypeColumn(STRING_TYPE, rows, ordinal)
+      case otherType => throw new UnsupportedOperationException(s"type $otherType")
+    }
+
+  override def toTColumnValue(ordinal: Int, row: Seq[String], types: Seq[String]): TColumnValue =
+    getColumnType(types, ordinal) match {
+      case COL_STRING_TYPE => toTTypeColumnVal(STRING_TYPE, row, ordinal)
+      case otherType => throw new UnsupportedOperationException(s"type $otherType")
+    }
+}
+
+object ChatTRowSetGenerator {
+  val COL_STRING_TYPE: String = classOf[String].getSimpleName
+}
diff --git a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/schema/RowSet.scala b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/schema/RowSet.scala
deleted file mode 100644
index 82794000160..00000000000
--- a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/schema/RowSet.scala
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License. 
- */ - -package org.apache.kyuubi.engine.chat.schema - -import java.util - -import org.apache.kyuubi.shaded.hive.service.rpc.thrift._ -import org.apache.kyuubi.util.RowSetUtils._ - -object RowSet { - - def emptyTRowSet(): TRowSet = { - new TRowSet(0, new java.util.ArrayList[TRow](0)) - } - - def toTRowSet( - rows: Seq[Array[String]], - columnSize: Int, - protocolVersion: TProtocolVersion): TRowSet = { - if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) { - toRowBasedSet(rows, columnSize) - } else { - toColumnBasedSet(rows, columnSize) - } - } - - def toRowBasedSet(rows: Seq[Array[String]], columnSize: Int): TRowSet = { - val rowSize = rows.length - val tRows = new java.util.ArrayList[TRow](rowSize) - var i = 0 - while (i < rowSize) { - val row = rows(i) - val tRow = new TRow() - var j = 0 - val columnSize = row.length - while (j < columnSize) { - val columnValue = stringTColumnValue(j, row) - tRow.addToColVals(columnValue) - j += 1 - } - i += 1 - tRows.add(tRow) - } - new TRowSet(0, tRows) - } - - def toColumnBasedSet(rows: Seq[Array[String]], columnSize: Int): TRowSet = { - val rowSize = rows.length - val tRowSet = new TRowSet(0, new util.ArrayList[TRow](rowSize)) - var i = 0 - while (i < columnSize) { - val tColumn = toTColumn(rows, i) - tRowSet.addToColumns(tColumn) - i += 1 - } - tRowSet - } - - private def toTColumn(rows: Seq[Array[String]], ordinal: Int): TColumn = { - val nulls = new java.util.BitSet() - val values = getOrSetAsNull[String](rows, ordinal, nulls, "") - TColumn.stringVal(new TStringColumn(values, nulls)) - } - - private def getOrSetAsNull[String]( - rows: Seq[Array[String]], - ordinal: Int, - nulls: util.BitSet, - defaultVal: String): util.List[String] = { - val size = rows.length - val ret = new util.ArrayList[String](size) - var idx = 0 - while (idx < size) { - val row = rows(idx) - val isNull = row(ordinal) == null - if (isNull) { - nulls.set(idx, true) - ret.add(idx, defaultVal) - } else { - ret.add(idx, row(ordinal)) - } - idx += 1 - } - ret - } - - private def stringTColumnValue(ordinal: Int, row: Array[String]): TColumnValue = { - val tStringValue = new TStringValue - if (row(ordinal) != null) tStringValue.setValue(row(ordinal)) - TColumnValue.stringVal(tStringValue) - } -} diff --git a/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/operation/FlinkOperation.scala b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/operation/FlinkOperation.scala index ff2e99c0c1b..df067a888c6 100644 --- a/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/operation/FlinkOperation.scala +++ b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/operation/FlinkOperation.scala @@ -31,7 +31,7 @@ import org.apache.flink.types.Row import org.apache.kyuubi.{KyuubiSQLException, Utils} import org.apache.kyuubi.engine.flink.result.ResultSet -import org.apache.kyuubi.engine.flink.schema.RowSet +import org.apache.kyuubi.engine.flink.schema.{FlinkTRowSetGenerator, RowSet} import org.apache.kyuubi.engine.flink.session.FlinkSessionImpl import org.apache.kyuubi.operation.{AbstractOperation, OperationState} import org.apache.kyuubi.operation.FetchOrientation.{FETCH_FIRST, FETCH_NEXT, FETCH_PRIOR, FetchOrientation} @@ -133,10 +133,9 @@ abstract class FlinkOperation(session: Session) extends AbstractOperation(sessio case Some(tz) => ZoneId.of(tz) case None => ZoneId.systemDefault() } - val resultRowSet = RowSet.resultSetToTRowSet( + val 
resultRowSet = new FlinkTRowSetGenerator(zoneId).toTRowSet( batch.toList, resultSet, - zoneId, getProtocolVersion) val resp = new TFetchResultsResp(OK_STATUS) resp.setResults(resultRowSet) diff --git a/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/FlinkTRowSetGenerator.scala b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/FlinkTRowSetGenerator.scala new file mode 100644 index 00000000000..b53aab47fb4 --- /dev/null +++ b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/FlinkTRowSetGenerator.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kyuubi.engine.flink.schema + +import java.time.{Instant, ZonedDateTime, ZoneId} + +import scala.collection.JavaConverters._ + +import org.apache.flink.table.data.StringData +import org.apache.flink.table.types.logical._ +import org.apache.flink.types.Row + +import org.apache.kyuubi.engine.flink.result.ResultSet +import org.apache.kyuubi.engine.flink.schema.RowSet.{toHiveString, TIMESTAMP_LZT_FORMATTER} +import org.apache.kyuubi.engine.schema.AbstractTRowSetGenerator +import org.apache.kyuubi.shaded.hive.service.rpc.thrift._ +import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TTypeId._ +import org.apache.kyuubi.util.RowSetUtils.bitSetToBuffer + +class FlinkTRowSetGenerator(zoneId: ZoneId) + extends AbstractTRowSetGenerator[ResultSet, Row, LogicalType] { + override def getColumnSizeFromSchemaType(schema: ResultSet): Int = schema.columns.size + + override def getColumnType(schema: ResultSet, ordinal: Int): LogicalType = + schema.columns.get(ordinal).getDataType.getLogicalType + + override def isColumnNullAt(row: Row, ordinal: Int): Boolean = row.getField(ordinal) == null + + override def getColumnAs[T](row: Row, ordinal: Int): T = row.getFieldAs[T](ordinal) + + override def toTColumnValue(ordinal: Int, row: Row, types: ResultSet): TColumnValue = { + getColumnType(types, ordinal) match { + case _: BooleanType => toTTypeColumnVal(BOOLEAN_TYPE, row, ordinal) + case _: TinyIntType => toTTypeColumnVal(BINARY_TYPE, row, ordinal) + case _: SmallIntType => toTTypeColumnVal(TINYINT_TYPE, row, ordinal) + case _: IntType => toTTypeColumnVal(INT_TYPE, row, ordinal) + case _: BigIntType => toTTypeColumnVal(BIGINT_TYPE, row, ordinal) + case _: DoubleType => toTTypeColumnVal(DOUBLE_TYPE, row, ordinal) + case _: FloatType => toTTypeColumnVal(FLOAT_TYPE, row, ordinal) + case t @ (_: VarCharType | _: CharType) => + val tStringValue = new TStringValue + val fieldValue = row.getField(ordinal) + fieldValue match { + case value: String => + tStringValue.setValue(value) + case value: StringData => + tStringValue.setValue(value.toString) + case null => + 
tStringValue.setValue(null)
+          case other =>
+            throw new IllegalArgumentException(
+              s"Unsupported conversion class ${other.getClass} " +
+                s"for type ${t.getClass}.")
+        }
+        TColumnValue.stringVal(tStringValue)
+      case _: LocalZonedTimestampType =>
+        val tStringValue = new TStringValue
+        val fieldValue = row.getField(ordinal)
+        tStringValue.setValue(TIMESTAMP_LZT_FORMATTER.format(
+          ZonedDateTime.ofInstant(fieldValue.asInstanceOf[Instant], zoneId)))
+        TColumnValue.stringVal(tStringValue)
+      case t =>
+        val tStringValue = new TStringValue
+        if (row.getField(ordinal) != null) {
+          tStringValue.setValue(toHiveString((row.getField(ordinal), t)))
+        }
+        TColumnValue.stringVal(tStringValue)
+    }
+  }
+
+  override def toTColumn(rows: Seq[Row], ordinal: Int, logicalType: LogicalType): TColumn = {
+    val nulls = new java.util.BitSet()
+    // for each column, determine the conversion class by sampling the first non-null value
+    // if there's no row, set the entire column empty
+    val sampleField = rows.iterator.map(_.getField(ordinal)).find(_ ne null).orNull
+    logicalType match {
+      case _: BooleanType => toTTypeColumn(BOOLEAN_TYPE, rows, ordinal)
+      case _: TinyIntType => toTTypeColumn(BINARY_TYPE, rows, ordinal)
+      case _: SmallIntType => toTTypeColumn(TINYINT_TYPE, rows, ordinal)
+      case _: IntType => toTTypeColumn(INT_TYPE, rows, ordinal)
+      case _: BigIntType => toTTypeColumn(BIGINT_TYPE, rows, ordinal)
+      case _: FloatType => toTTypeColumn(FLOAT_TYPE, rows, ordinal)
+      case _: DoubleType => toTTypeColumn(DOUBLE_TYPE, rows, ordinal)
+      case t @ (_: VarCharType | _: CharType) =>
+        val values: java.util.List[String] = new java.util.ArrayList[String](0)
+        sampleField match {
+          case _: String =>
+            values.addAll(getOrSetAsNull[String](rows, ordinal, nulls, ""))
+          case _: StringData =>
+            val stringDataValues =
+              getOrSetAsNull[StringData](rows, ordinal, nulls, StringData.fromString(""))
+            stringDataValues.forEach(e => values.add(e.toString))
+          case null =>
+            values.addAll(getOrSetAsNull[String](rows, ordinal, nulls, ""))
+          case other =>
+            throw new IllegalArgumentException(
+              s"Unsupported conversion class ${other.getClass} " +
+                s"for type ${t.getClass}.")
+        }
+        TColumn.stringVal(new TStringColumn(values, nulls))
+      case _: LocalZonedTimestampType =>
+        val values = getOrSetAsNull[Instant](rows, ordinal, nulls, Instant.EPOCH)
+          .toArray().map(v =>
+            TIMESTAMP_LZT_FORMATTER.format(
+              ZonedDateTime.ofInstant(v.asInstanceOf[Instant], zoneId)))
+        TColumn.stringVal(new TStringColumn(values.toList.asJava, nulls))
+      case _ =>
+        var i = 0
+        val rowSize = rows.length
+        val values = new java.util.ArrayList[String](rowSize)
+        while (i < rowSize) {
+          val row = rows(i)
+          nulls.set(i, row.getField(ordinal) == null)
+          val value =
+            if (row.getField(ordinal) == null) {
+              ""
+            } else {
+              toHiveString((row.getField(ordinal), logicalType))
+            }
+          values.add(value)
+          i += 1
+        }
+        TColumn.stringVal(new TStringColumn(values, nulls))
+    }
+  }
+
+}
diff --git a/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/RowSet.scala b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/RowSet.scala
index a000869cc5b..7015d7c52b6 100644
--- a/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/RowSet.scala
+++ b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/RowSet.scala
@@ -17,262 +17,25 @@ package org.apache.kyuubi.engine.flink.schema
-import java.{lang, util}
-import java.nio.ByteBuffer
 import 
java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} -import java.time.{Instant, LocalDate, LocalDateTime, ZonedDateTime, ZoneId} +import java.time.{LocalDate, LocalDateTime} import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder, TextStyle} import java.time.temporal.ChronoField import java.util.Collections import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer -import scala.language.implicitConversions import org.apache.flink.table.catalog.Column -import org.apache.flink.table.data.StringData import org.apache.flink.table.types.logical._ import org.apache.flink.types.Row -import org.apache.kyuubi.engine.flink.result.ResultSet import org.apache.kyuubi.shaded.hive.service.rpc.thrift._ import org.apache.kyuubi.util.RowSetUtils._ object RowSet { - def resultSetToTRowSet( - rows: Seq[Row], - resultSet: ResultSet, - zoneId: ZoneId, - protocolVersion: TProtocolVersion): TRowSet = { - if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) { - toRowBaseSet(rows, resultSet, zoneId) - } else { - toColumnBasedSet(rows, resultSet, zoneId) - } - } - - def toRowBaseSet(rows: Seq[Row], resultSet: ResultSet, zoneId: ZoneId): TRowSet = { - val rowSize = rows.size - val tRows = new util.ArrayList[TRow](rowSize) - var i = 0 - while (i < rowSize) { - val row = rows(i) - val tRow = new TRow() - val columnSize = row.getArity - var j = 0 - while (j < columnSize) { - val columnValue = toTColumnValue(j, row, resultSet, zoneId) - tRow.addToColVals(columnValue) - j += 1 - } - tRows.add(tRow) - i += 1 - } - - new TRowSet(0, tRows) - } - - def toColumnBasedSet(rows: Seq[Row], resultSet: ResultSet, zoneId: ZoneId): TRowSet = { - val size = rows.length - val tRowSet = new TRowSet(0, new util.ArrayList[TRow](size)) - val columnSize = resultSet.getColumns.size() - var i = 0 - while (i < columnSize) { - val field = resultSet.getColumns.get(i) - val tColumn = toTColumn(rows, i, field.getDataType.getLogicalType, zoneId) - tRowSet.addToColumns(tColumn) - i += 1 - } - tRowSet - } - - private def toTColumnValue( - ordinal: Int, - row: Row, - resultSet: ResultSet, - zoneId: ZoneId): TColumnValue = { - - val column = resultSet.getColumns.get(ordinal) - val logicalType = column.getDataType.getLogicalType - - logicalType match { - case _: BooleanType => - val boolValue = new TBoolValue - if (row.getField(ordinal) != null) { - boolValue.setValue(row.getField(ordinal).asInstanceOf[Boolean]) - } - TColumnValue.boolVal(boolValue) - case _: TinyIntType => - val tByteValue = new TByteValue - if (row.getField(ordinal) != null) { - tByteValue.setValue(row.getField(ordinal).asInstanceOf[Byte]) - } - TColumnValue.byteVal(tByteValue) - case _: SmallIntType => - val tI16Value = new TI16Value - if (row.getField(ordinal) != null) { - tI16Value.setValue(row.getField(ordinal).asInstanceOf[Short]) - } - TColumnValue.i16Val(tI16Value) - case _: IntType => - val tI32Value = new TI32Value - if (row.getField(ordinal) != null) { - tI32Value.setValue(row.getField(ordinal).asInstanceOf[Int]) - } - TColumnValue.i32Val(tI32Value) - case _: BigIntType => - val tI64Value = new TI64Value - if (row.getField(ordinal) != null) { - tI64Value.setValue(row.getField(ordinal).asInstanceOf[Long]) - } - TColumnValue.i64Val(tI64Value) - case _: FloatType => - val tDoubleValue = new TDoubleValue - if (row.getField(ordinal) != null) { - val doubleValue = lang.Double.valueOf(row.getField(ordinal).asInstanceOf[Float].toString) - tDoubleValue.setValue(doubleValue) - } - 
TColumnValue.doubleVal(tDoubleValue) - case _: DoubleType => - val tDoubleValue = new TDoubleValue - if (row.getField(ordinal) != null) { - tDoubleValue.setValue(row.getField(ordinal).asInstanceOf[Double]) - } - TColumnValue.doubleVal(tDoubleValue) - case t @ (_: VarCharType | _: CharType) => - val tStringValue = new TStringValue - val fieldValue = row.getField(ordinal) - fieldValue match { - case value: String => - tStringValue.setValue(value) - case value: StringData => - tStringValue.setValue(value.toString) - case null => - tStringValue.setValue(null) - case other => - throw new IllegalArgumentException( - s"Unsupported conversion class ${other.getClass} " + - s"for type ${t.getClass}.") - } - TColumnValue.stringVal(tStringValue) - case _: LocalZonedTimestampType => - val tStringValue = new TStringValue - val fieldValue = row.getField(ordinal) - tStringValue.setValue(TIMESTAMP_LZT_FORMATTER.format( - ZonedDateTime.ofInstant(fieldValue.asInstanceOf[Instant], zoneId))) - TColumnValue.stringVal(tStringValue) - case t => - val tStringValue = new TStringValue - if (row.getField(ordinal) != null) { - tStringValue.setValue(toHiveString((row.getField(ordinal), t))) - } - TColumnValue.stringVal(tStringValue) - } - } - - implicit private def bitSetToBuffer(bitSet: java.util.BitSet): ByteBuffer = { - ByteBuffer.wrap(bitSet.toByteArray) - } - - private def toTColumn( - rows: Seq[Row], - ordinal: Int, - logicalType: LogicalType, - zoneId: ZoneId): TColumn = { - val nulls = new java.util.BitSet() - // for each column, determine the conversion class by sampling the first non-value value - // if there's no row, set the entire column empty - val sampleField = rows.iterator.map(_.getField(ordinal)).find(_ ne null).orNull - logicalType match { - case _: BooleanType => - val values = getOrSetAsNull[lang.Boolean](rows, ordinal, nulls, true) - TColumn.boolVal(new TBoolColumn(values, nulls)) - case _: TinyIntType => - val values = getOrSetAsNull[lang.Byte](rows, ordinal, nulls, 0.toByte) - TColumn.byteVal(new TByteColumn(values, nulls)) - case _: SmallIntType => - val values = getOrSetAsNull[lang.Short](rows, ordinal, nulls, 0.toShort) - TColumn.i16Val(new TI16Column(values, nulls)) - case _: IntType => - val values = getOrSetAsNull[lang.Integer](rows, ordinal, nulls, 0) - TColumn.i32Val(new TI32Column(values, nulls)) - case _: BigIntType => - val values = getOrSetAsNull[lang.Long](rows, ordinal, nulls, 0L) - TColumn.i64Val(new TI64Column(values, nulls)) - case _: FloatType => - val values = getOrSetAsNull[lang.Float](rows, ordinal, nulls, 0.0f) - .asScala.map(n => lang.Double.valueOf(n.toString)).asJava - TColumn.doubleVal(new TDoubleColumn(values, nulls)) - case _: DoubleType => - val values = getOrSetAsNull[lang.Double](rows, ordinal, nulls, 0.0) - TColumn.doubleVal(new TDoubleColumn(values, nulls)) - case t @ (_: VarCharType | _: CharType) => - val values: util.List[String] = new util.ArrayList[String](0) - sampleField match { - case _: String => - values.addAll(getOrSetAsNull[String](rows, ordinal, nulls, "")) - case _: StringData => - val stringDataValues = - getOrSetAsNull[StringData](rows, ordinal, nulls, StringData.fromString("")) - stringDataValues.forEach(e => values.add(e.toString)) - case null => - values.addAll(getOrSetAsNull[String](rows, ordinal, nulls, "")) - case other => - throw new IllegalArgumentException( - s"Unsupported conversion class ${other.getClass} " + - s"for type ${t.getClass}.") - } - TColumn.stringVal(new TStringColumn(values, nulls)) - case _: LocalZonedTimestampType => - 
val values = getOrSetAsNull[Instant](rows, ordinal, nulls, Instant.EPOCH) - .toArray().map(v => - TIMESTAMP_LZT_FORMATTER.format( - ZonedDateTime.ofInstant(v.asInstanceOf[Instant], zoneId))) - TColumn.stringVal(new TStringColumn(values.toList.asJava, nulls)) - case _ => - var i = 0 - val rowSize = rows.length - val values = new java.util.ArrayList[String](rowSize) - while (i < rowSize) { - val row = rows(i) - nulls.set(i, row.getField(ordinal) == null) - val value = - if (row.getField(ordinal) == null) { - "" - } else { - toHiveString((row.getField(ordinal), logicalType)) - } - values.add(value) - i += 1 - } - TColumn.stringVal(new TStringColumn(values, nulls)) - } - } - - private def getOrSetAsNull[T]( - rows: Seq[Row], - ordinal: Int, - nulls: java.util.BitSet, - defaultVal: T): java.util.List[T] = { - val size = rows.length - val ret = new java.util.ArrayList[T](size) - var idx = 0 - while (idx < size) { - val row = rows(idx) - val isNull = row.getField(ordinal) == null - if (isNull) { - nulls.set(idx, true) - ret.add(idx, defaultVal) - } else { - ret.add(idx, row.getFieldAs[T](ordinal)) - } - idx += 1 - } - ret - } - def toTColumnDesc(field: Column, pos: Int): TColumnDesc = { val tColumnDesc = new TColumnDesc() tColumnDesc.setColumnName(field.getName) diff --git a/externals/kyuubi-flink-sql-engine/src/test/scala/org/apache/kyuubi/engine/flink/result/ResultSetSuite.scala b/externals/kyuubi-flink-sql-engine/src/test/scala/org/apache/kyuubi/engine/flink/result/ResultSetSuite.scala index 9ee5c658bc9..5e58d433f91 100644 --- a/externals/kyuubi-flink-sql-engine/src/test/scala/org/apache/kyuubi/engine/flink/result/ResultSetSuite.scala +++ b/externals/kyuubi-flink-sql-engine/src/test/scala/org/apache/kyuubi/engine/flink/result/ResultSetSuite.scala @@ -25,7 +25,7 @@ import org.apache.flink.table.data.StringData import org.apache.flink.types.Row import org.apache.kyuubi.KyuubiFunSuite -import org.apache.kyuubi.engine.flink.schema.RowSet +import org.apache.kyuubi.engine.flink.schema.FlinkTRowSetGenerator class ResultSetSuite extends KyuubiFunSuite { @@ -47,9 +47,9 @@ class ResultSetSuite extends KyuubiFunSuite { .build val timeZone = ZoneId.of("America/Los_Angeles") - assert(RowSet.toRowBaseSet(rowsNew, resultSetNew, timeZone) - === RowSet.toRowBaseSet(rowsOld, resultSetOld, timeZone)) - assert(RowSet.toColumnBasedSet(rowsNew, resultSetNew, timeZone) - === RowSet.toColumnBasedSet(rowsOld, resultSetOld, timeZone)) + assert(new FlinkTRowSetGenerator(timeZone).toRowBasedSet(rowsNew, resultSetNew) + === new FlinkTRowSetGenerator(timeZone).toRowBasedSet(rowsOld, resultSetOld)) + assert(new FlinkTRowSetGenerator(timeZone).toColumnBasedSet(rowsNew, resultSetNew) + === new FlinkTRowSetGenerator(timeZone).toColumnBasedSet(rowsOld, resultSetOld)) } } diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala index e2fc80c6b87..1d271cfcec8 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala @@ -24,7 +24,7 @@ import org.apache.spark.kyuubi.{SparkProgressMonitor, SQLOperationListener} import org.apache.spark.kyuubi.SparkUtilsHelper.redact import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.execution.SQLExecution 
-import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{BinaryType, StructField, StructType} import org.apache.kyuubi.{KyuubiSQLException, Utils} import org.apache.kyuubi.config.KyuubiConf @@ -33,7 +33,7 @@ import org.apache.kyuubi.config.KyuubiReservedKeys.{KYUUBI_SESSION_SIGN_PUBLICKE import org.apache.kyuubi.engine.spark.KyuubiSparkUtil.{getSessionConf, SPARK_SCHEDULER_POOL_KEY} import org.apache.kyuubi.engine.spark.events.SparkOperationEvent import org.apache.kyuubi.engine.spark.operation.SparkOperation.TIMEZONE_KEY -import org.apache.kyuubi.engine.spark.schema.{RowSet, SchemaHelper} +import org.apache.kyuubi.engine.spark.schema.{SchemaHelper, SparkArrowTRowSetGenerator, SparkTRowSetGenerator} import org.apache.kyuubi.engine.spark.session.SparkSessionImpl import org.apache.kyuubi.events.EventBus import org.apache.kyuubi.operation.{AbstractOperation, FetchIterator, OperationState, OperationStatus} @@ -42,6 +42,7 @@ import org.apache.kyuubi.operation.OperationState.OperationState import org.apache.kyuubi.operation.log.OperationLog import org.apache.kyuubi.session.Session import org.apache.kyuubi.shaded.hive.service.rpc.thrift.{TFetchResultsResp, TGetResultSetMetadataResp, TProgressUpdateResp, TRowSet} +import org.apache.kyuubi.util.ThriftUtils abstract class SparkOperation(session: Session) extends AbstractOperation(session) { @@ -243,13 +244,16 @@ abstract class SparkOperation(session: Session) if (isArrowBasedOperation) { if (iter.hasNext) { val taken = iter.next().asInstanceOf[Array[Byte]] - RowSet.toTRowSet(taken, getProtocolVersion) + new SparkArrowTRowSetGenerator().toTRowSet( + Seq(taken), + new StructType().add(StructField(null, BinaryType)), + getProtocolVersion) } else { - RowSet.emptyTRowSet() + ThriftUtils.newEmptyRowSet } } else { val taken = iter.take(rowSetSize) - RowSet.toTRowSet( + new SparkTRowSetGenerator().toTRowSet( taken.toSeq.asInstanceOf[Seq[Row]], resultSchema, getProtocolVersion) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala index 806451907b1..c5f32210891 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala @@ -17,18 +17,10 @@ package org.apache.kyuubi.engine.spark.schema -import java.nio.ByteBuffer - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.Row import org.apache.spark.sql.execution.HiveResult import org.apache.spark.sql.execution.HiveResult.TimeFormatters import org.apache.spark.sql.types._ -import org.apache.kyuubi.shaded.hive.service.rpc.thrift._ -import org.apache.kyuubi.util.RowSetUtils._ - object RowSet { def toHiveString( @@ -38,224 +30,4 @@ object RowSet { HiveResult.toHiveString(valueAndType, nested, timeFormatters) } - def toTRowSet( - bytes: Array[Byte], - protocolVersion: TProtocolVersion): TRowSet = { - if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) { - throw new UnsupportedOperationException - } else { - toColumnBasedSet(bytes) - } - } - - def emptyTRowSet(): TRowSet = { - new TRowSet(0, new java.util.ArrayList[TRow](0)) - } - - def toColumnBasedSet(data: Array[Byte]): TRowSet = { - val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](1)) - val tColumn = toTColumn(data) - tRowSet.addToColumns(tColumn) - tRowSet - } - 
- def toTRowSet( - rows: Seq[Row], - schema: StructType, - protocolVersion: TProtocolVersion): TRowSet = { - if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) { - toRowBasedSet(rows, schema) - } else { - toColumnBasedSet(rows, schema) - } - } - - def toRowBasedSet(rows: Seq[Row], schema: StructType): TRowSet = { - val rowSize = rows.length - val tRows = new java.util.ArrayList[TRow](rowSize) - val timeFormatters = HiveResult.getTimeFormatters - var i = 0 - while (i < rowSize) { - val row = rows(i) - var j = 0 - val columnSize = row.length - val tColumnValues = new java.util.ArrayList[TColumnValue](columnSize) - while (j < columnSize) { - val columnValue = toTColumnValue(j, row, schema, timeFormatters) - tColumnValues.add(columnValue) - j += 1 - } - i += 1 - val tRow = new TRow(tColumnValues) - tRows.add(tRow) - } - new TRowSet(0, tRows) - } - - def toColumnBasedSet(rows: Seq[Row], schema: StructType): TRowSet = { - val rowSize = rows.length - val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](rowSize)) - val timeFormatters = HiveResult.getTimeFormatters - var i = 0 - val columnSize = schema.length - val tColumns = new java.util.ArrayList[TColumn](columnSize) - while (i < columnSize) { - val field = schema(i) - val tColumn = toTColumn(rows, i, field.dataType, timeFormatters) - tColumns.add(tColumn) - i += 1 - } - tRowSet.setColumns(tColumns) - tRowSet - } - - private def toTColumn( - rows: Seq[Row], - ordinal: Int, - typ: DataType, - timeFormatters: TimeFormatters): TColumn = { - val nulls = new java.util.BitSet() - typ match { - case BooleanType => - val values = getOrSetAsNull[java.lang.Boolean](rows, ordinal, nulls, true) - TColumn.boolVal(new TBoolColumn(values, nulls)) - - case ByteType => - val values = getOrSetAsNull[java.lang.Byte](rows, ordinal, nulls, 0.toByte) - TColumn.byteVal(new TByteColumn(values, nulls)) - - case ShortType => - val values = getOrSetAsNull[java.lang.Short](rows, ordinal, nulls, 0.toShort) - TColumn.i16Val(new TI16Column(values, nulls)) - - case IntegerType => - val values = getOrSetAsNull[java.lang.Integer](rows, ordinal, nulls, 0) - TColumn.i32Val(new TI32Column(values, nulls)) - - case LongType => - val values = getOrSetAsNull[java.lang.Long](rows, ordinal, nulls, 0L) - TColumn.i64Val(new TI64Column(values, nulls)) - - case FloatType => - val values = getOrSetAsNull[java.lang.Float](rows, ordinal, nulls, 0.toFloat) - .asScala.map(n => java.lang.Double.valueOf(n.toString)).asJava - TColumn.doubleVal(new TDoubleColumn(values, nulls)) - - case DoubleType => - val values = getOrSetAsNull[java.lang.Double](rows, ordinal, nulls, 0.toDouble) - TColumn.doubleVal(new TDoubleColumn(values, nulls)) - - case StringType => - val values = getOrSetAsNull[java.lang.String](rows, ordinal, nulls, "") - TColumn.stringVal(new TStringColumn(values, nulls)) - - case BinaryType => - val values = getOrSetAsNull[Array[Byte]](rows, ordinal, nulls, Array()) - .asScala - .map(ByteBuffer.wrap) - .asJava - TColumn.binaryVal(new TBinaryColumn(values, nulls)) - - case _ => - var i = 0 - val rowSize = rows.length - val values = new java.util.ArrayList[String](rowSize) - while (i < rowSize) { - val row = rows(i) - nulls.set(i, row.isNullAt(ordinal)) - values.add(toHiveString(row.get(ordinal) -> typ, timeFormatters = timeFormatters)) - i += 1 - } - TColumn.stringVal(new TStringColumn(values, nulls)) - } - } - - private def getOrSetAsNull[T]( - rows: Seq[Row], - ordinal: Int, - nulls: java.util.BitSet, - defaultVal: T): java.util.List[T] = { - val 
size = rows.length - val ret = new java.util.ArrayList[T](size) - var idx = 0 - while (idx < size) { - val row = rows(idx) - val isNull = row.isNullAt(ordinal) - if (isNull) { - nulls.set(idx, true) - ret.add(idx, defaultVal) - } else { - ret.add(idx, row.getAs[T](ordinal)) - } - idx += 1 - } - ret - } - - private def toTColumnValue( - ordinal: Int, - row: Row, - types: StructType, - timeFormatters: TimeFormatters): TColumnValue = { - types(ordinal).dataType match { - case BooleanType => - val boolValue = new TBoolValue - if (!row.isNullAt(ordinal)) boolValue.setValue(row.getBoolean(ordinal)) - TColumnValue.boolVal(boolValue) - - case ByteType => - val byteValue = new TByteValue - if (!row.isNullAt(ordinal)) byteValue.setValue(row.getByte(ordinal)) - TColumnValue.byteVal(byteValue) - - case ShortType => - val tI16Value = new TI16Value - if (!row.isNullAt(ordinal)) tI16Value.setValue(row.getShort(ordinal)) - TColumnValue.i16Val(tI16Value) - - case IntegerType => - val tI32Value = new TI32Value - if (!row.isNullAt(ordinal)) tI32Value.setValue(row.getInt(ordinal)) - TColumnValue.i32Val(tI32Value) - - case LongType => - val tI64Value = new TI64Value - if (!row.isNullAt(ordinal)) tI64Value.setValue(row.getLong(ordinal)) - TColumnValue.i64Val(tI64Value) - - case FloatType => - val tDoubleValue = new TDoubleValue - if (!row.isNullAt(ordinal)) { - val doubleValue = java.lang.Double.valueOf(row.getFloat(ordinal).toString) - tDoubleValue.setValue(doubleValue) - } - TColumnValue.doubleVal(tDoubleValue) - - case DoubleType => - val tDoubleValue = new TDoubleValue - if (!row.isNullAt(ordinal)) tDoubleValue.setValue(row.getDouble(ordinal)) - TColumnValue.doubleVal(tDoubleValue) - - case StringType => - val tStringValue = new TStringValue - if (!row.isNullAt(ordinal)) tStringValue.setValue(row.getString(ordinal)) - TColumnValue.stringVal(tStringValue) - - case _ => - val tStrValue = new TStringValue - if (!row.isNullAt(ordinal)) { - tStrValue.setValue(toHiveString( - row.get(ordinal) -> types(ordinal).dataType, - timeFormatters = timeFormatters)) - } - TColumnValue.stringVal(tStrValue) - } - } - - private def toTColumn(data: Array[Byte]): TColumn = { - val values = new java.util.ArrayList[ByteBuffer](1) - values.add(ByteBuffer.wrap(data)) - val nulls = new java.util.BitSet() - TColumn.binaryVal(new TBinaryColumn(values, nulls)) - } } diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/SparkArrowTRowSetGenerator.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/SparkArrowTRowSetGenerator.scala new file mode 100644 index 00000000000..ded022ad032 --- /dev/null +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/SparkArrowTRowSetGenerator.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.spark.schema + +import java.nio.ByteBuffer + +import org.apache.spark.sql.types._ + +import org.apache.kyuubi.engine.schema.AbstractTRowSetGenerator +import org.apache.kyuubi.shaded.hive.service.rpc.thrift._ +import org.apache.kyuubi.util.RowSetUtils.bitSetToBuffer + +class SparkArrowTRowSetGenerator + extends AbstractTRowSetGenerator[StructType, Array[Byte], DataType] { + override def toColumnBasedSet(rows: Seq[Array[Byte]], schema: StructType): TRowSet = { + require(schema.length == 1, "ArrowRowSetGenerator accepts only one single byte array") + require(schema.head.dataType == BinaryType, "ArrowRowSetGenerator accepts only BinaryType") + + val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](1)) + val tColumn = toTColumn(rows, 1, schema.head.dataType) + tRowSet.addToColumns(tColumn) + tRowSet + } + + override def toTColumn(rows: Seq[Array[Byte]], ordinal: Int, typ: DataType): TColumn = { + require(rows.length == 1, "ArrowRowSetGenerator accepts only one single byte array") + typ match { + case BinaryType => + val values = new java.util.ArrayList[ByteBuffer](1) + values.add(ByteBuffer.wrap(rows.head)) + val nulls = new java.util.BitSet() + TColumn.binaryVal(new TBinaryColumn(values, nulls)) + case _ => throw new IllegalArgumentException( + s"unsupported datatype $typ, ArrowRowSetGenerator accepts only BinaryType") + } + } + + override def toRowBasedSet(rows: Seq[Array[Byte]], schema: StructType): TRowSet = { + throw new UnsupportedOperationException + } + + override def getColumnSizeFromSchemaType(schema: StructType): Int = { + throw new UnsupportedOperationException + } + + override def getColumnType(schema: StructType, ordinal: Int): DataType = { + throw new UnsupportedOperationException + } + + override def isColumnNullAt(row: Array[Byte], ordinal: Int): Boolean = { + throw new UnsupportedOperationException + } + + override def getColumnAs[T](row: Array[Byte], ordinal: Int): T = { + throw new UnsupportedOperationException + } + + override def toTColumnValue(ordinal: Int, row: Array[Byte], types: StructType): TColumnValue = { + throw new UnsupportedOperationException + } + +} diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/SparkTRowSetGenerator.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/SparkTRowSetGenerator.scala new file mode 100644 index 00000000000..a35455292aa --- /dev/null +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/SparkTRowSetGenerator.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.spark.schema + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.HiveResult +import org.apache.spark.sql.execution.HiveResult.TimeFormatters +import org.apache.spark.sql.types._ + +import org.apache.kyuubi.engine.schema.AbstractTRowSetGenerator +import org.apache.kyuubi.shaded.hive.service.rpc.thrift._ +import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TTypeId._ +import org.apache.kyuubi.util.RowSetUtils.bitSetToBuffer + +class SparkTRowSetGenerator + extends AbstractTRowSetGenerator[StructType, Row, DataType] { + + // reused time formatters in single RowSet generation, see KYUUBI-5811 + private val tf = HiveResult.getTimeFormatters + + override def getColumnSizeFromSchemaType(schema: StructType): Int = schema.length + + override def getColumnType(schema: StructType, ordinal: Int): DataType = schema(ordinal).dataType + + override def isColumnNullAt(row: Row, ordinal: Int): Boolean = row.isNullAt(ordinal) + + override def getColumnAs[T](row: Row, ordinal: Int): T = row.getAs[T](ordinal) + + override def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType): TColumn = { + val timeFormatters: TimeFormatters = tf + val nulls = new java.util.BitSet() + typ match { + case BooleanType => toTTypeColumn(BOOLEAN_TYPE, rows, ordinal) + case ByteType => toTTypeColumn(BINARY_TYPE, rows, ordinal) + case ShortType => toTTypeColumn(TINYINT_TYPE, rows, ordinal) + case IntegerType => toTTypeColumn(INT_TYPE, rows, ordinal) + case LongType => toTTypeColumn(BIGINT_TYPE, rows, ordinal) + case FloatType => toTTypeColumn(FLOAT_TYPE, rows, ordinal) + case DoubleType => toTTypeColumn(DOUBLE_TYPE, rows, ordinal) + case StringType => toTTypeColumn(STRING_TYPE, rows, ordinal) + case BinaryType => toTTypeColumn(ARRAY_TYPE, rows, ordinal) + case _ => + var i = 0 + val rowSize = rows.length + val values = new java.util.ArrayList[String](rowSize) + while (i < rowSize) { + val row = rows(i) + nulls.set(i, row.isNullAt(ordinal)) + values.add(RowSet.toHiveString(row.get(ordinal) -> typ, timeFormatters = timeFormatters)) + i += 1 + } + TColumn.stringVal(new TStringColumn(values, nulls)) + } + } + + override def toTColumnValue(ordinal: Int, row: Row, types: StructType): TColumnValue = { + val timeFormatters: TimeFormatters = tf + getColumnType(types, ordinal) match { + case BooleanType => toTTypeColumnVal(BOOLEAN_TYPE, row, ordinal) + case ByteType => toTTypeColumnVal(BINARY_TYPE, row, ordinal) + case ShortType => toTTypeColumnVal(TINYINT_TYPE, row, ordinal) + case IntegerType => toTTypeColumnVal(INT_TYPE, row, ordinal) + case LongType => toTTypeColumnVal(BIGINT_TYPE, row, ordinal) + case FloatType => toTTypeColumnVal(FLOAT_TYPE, row, ordinal) + case DoubleType => toTTypeColumnVal(DOUBLE_TYPE, row, ordinal) + case StringType => toTTypeColumnVal(STRING_TYPE, row, ordinal) + case _ => + val tStrValue = new TStringValue + if (!row.isNullAt(ordinal)) { + tStrValue.setValue(RowSet.toHiveString( + row.get(ordinal) -> types(ordinal).dataType, + timeFormatters = timeFormatters)) + } + TColumnValue.stringVal(tStrValue) + } + } + +} diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala index dec18589775..228bdcaf2c0 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala +++ 
b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala @@ -99,7 +99,7 @@ class RowSetSuite extends KyuubiFunSuite { private val rows: Seq[Row] = (0 to 10).map(genRow) ++ Seq(Row.fromSeq(Seq.fill(17)(null))) test("column based set") { - val tRowSet = RowSet.toColumnBasedSet(rows, schema) + val tRowSet = new SparkTRowSetGenerator().toColumnBasedSet(rows, schema) assert(tRowSet.getColumns.size() === schema.size) assert(tRowSet.getRowsSize === 0) @@ -210,7 +210,7 @@ class RowSetSuite extends KyuubiFunSuite { } test("row based set") { - val tRowSet = RowSet.toRowBasedSet(rows, schema) + val tRowSet = new SparkTRowSetGenerator().toRowBasedSet(rows, schema) assert(tRowSet.getColumnCount === 0) assert(tRowSet.getRowsSize === rows.size) val iter = tRowSet.getRowsIterator @@ -258,7 +258,7 @@ class RowSetSuite extends KyuubiFunSuite { test("to row set") { TProtocolVersion.values().foreach { proto => - val set = RowSet.toTRowSet(rows, schema, proto) + val set = new SparkTRowSetGenerator().toTRowSet(rows, schema, proto) if (proto.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) { assert(!set.isSetColumns, proto.toString) assert(set.isSetRows, proto.toString) diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/ExecuteStatement.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/ExecuteStatement.scala index 4f5049223c2..3de2ae59f42 100644 --- a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/ExecuteStatement.scala +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/ExecuteStatement.scala @@ -22,7 +22,7 @@ import java.util.concurrent.RejectedExecutionException import org.apache.kyuubi.{KyuubiSQLException, Logging} import org.apache.kyuubi.engine.trino.TrinoStatement import org.apache.kyuubi.engine.trino.event.TrinoOperationEvent -import org.apache.kyuubi.engine.trino.schema.RowSet +import org.apache.kyuubi.engine.trino.schema.TrinoTRowSetGenerator import org.apache.kyuubi.events.EventBus import org.apache.kyuubi.operation.{ArrayFetchIterator, FetchIterator, OperationState} import org.apache.kyuubi.operation.FetchOrientation.{FETCH_FIRST, FETCH_NEXT, FETCH_PRIOR, FetchOrientation} @@ -96,7 +96,8 @@ class ExecuteStatement( throw KyuubiSQLException(s"Fetch orientation[$order] is not supported in $mode mode") } val taken = iter.take(rowSetSize) - val resultRowSet = RowSet.toTRowSet(taken.toList, schema, getProtocolVersion) + val resultRowSet = new TrinoTRowSetGenerator() + .toTRowSet(taken.toList, schema, getProtocolVersion) resultRowSet.setStartRowOffset(iter.getPosition) val fetchResultsResp = new TFetchResultsResp(OK_STATUS) fetchResultsResp.setResults(resultRowSet) diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperation.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperation.scala index d82b11adc05..822f1726a3b 100644 --- a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperation.scala +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperation.scala @@ -25,8 +25,7 @@ import io.trino.client.StatementClient import org.apache.kyuubi.KyuubiSQLException import org.apache.kyuubi.Utils import org.apache.kyuubi.engine.trino.TrinoContext -import 
org.apache.kyuubi.engine.trino.schema.RowSet -import org.apache.kyuubi.engine.trino.schema.SchemaHelper +import org.apache.kyuubi.engine.trino.schema.{SchemaHelper, TrinoTRowSetGenerator} import org.apache.kyuubi.engine.trino.session.TrinoSessionImpl import org.apache.kyuubi.operation.AbstractOperation import org.apache.kyuubi.operation.FetchIterator @@ -66,7 +65,8 @@ abstract class TrinoOperation(session: Session) extends AbstractOperation(sessio case FETCH_FIRST => iter.fetchAbsolute(0) } val taken = iter.take(rowSetSize) - val resultRowSet = RowSet.toTRowSet(taken.toList, schema, getProtocolVersion) + val resultRowSet = + new TrinoTRowSetGenerator().toTRowSet(taken.toSeq, schema, getProtocolVersion) resultRowSet.setStartRowOffset(iter.getPosition) val resp = new TFetchResultsResp(OK_STATUS) resp.setResults(resultRowSet) diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/RowSet.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/RowSet.scala index 2bb16622eac..22e09f38138 100644 --- a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/RowSet.scala +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/RowSet.scala @@ -17,233 +17,16 @@ package org.apache.kyuubi.engine.trino.schema -import java.nio.ByteBuffer import java.nio.charset.StandardCharsets -import java.util import scala.collection.JavaConverters._ import io.trino.client.ClientStandardTypes._ import io.trino.client.ClientTypeSignature -import io.trino.client.Column import io.trino.client.Row -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TBinaryColumn -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TBoolColumn -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TBoolValue -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TByteColumn -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TByteValue -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TColumn -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TColumnValue -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TDoubleColumn -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TDoubleValue -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TI16Column -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TI16Value -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TI32Column -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TI32Value -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TI64Column -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TI64Value -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TProtocolVersion -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TRow -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TRowSet -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TStringColumn -import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TStringValue -import org.apache.kyuubi.util.RowSetUtils.bitSetToBuffer - object RowSet { - def toTRowSet( - rows: Seq[List[_]], - schema: List[Column], - protocolVersion: TProtocolVersion): TRowSet = { - if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) { - toRowBasedSet(rows, schema) - } else { - toColumnBasedSet(rows, schema) - } - } - - def toRowBasedSet(rows: Seq[List[_]], schema: List[Column]): TRowSet = { - val rowSize = rows.length - val tRows = new util.ArrayList[TRow](rowSize) - var i = 0 - while (i < rowSize) { - val row 
= rows(i) - val tRow = new TRow() - val columnSize = row.size - var j = 0 - while (j < columnSize) { - val columnValue = toTColumnValue(j, row, schema) - tRow.addToColVals(columnValue) - j += 1 - } - tRows.add(tRow) - i += 1 - } - new TRowSet(0, tRows) - } - - def toColumnBasedSet(rows: Seq[List[_]], schema: List[Column]): TRowSet = { - val size = rows.size - val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](size)) - val columnSize = schema.length - var i = 0 - while (i < columnSize) { - val field = schema(i) - val tColumn = toTColumn(rows, i, field.getTypeSignature) - tRowSet.addToColumns(tColumn) - i += 1 - } - tRowSet - } - - private def toTColumn( - rows: Seq[Seq[Any]], - ordinal: Int, - typ: ClientTypeSignature): TColumn = { - val nulls = new java.util.BitSet() - typ.getRawType match { - case BOOLEAN => - val values = getOrSetAsNull[java.lang.Boolean](rows, ordinal, nulls, true) - TColumn.boolVal(new TBoolColumn(values, nulls)) - - case TINYINT => - val values = getOrSetAsNull[java.lang.Byte](rows, ordinal, nulls, 0.toByte) - TColumn.byteVal(new TByteColumn(values, nulls)) - - case SMALLINT => - val values = getOrSetAsNull[java.lang.Short](rows, ordinal, nulls, 0.toShort) - TColumn.i16Val(new TI16Column(values, nulls)) - - case INTEGER => - val values = getOrSetAsNull[java.lang.Integer](rows, ordinal, nulls, 0) - TColumn.i32Val(new TI32Column(values, nulls)) - - case BIGINT => - val values = getOrSetAsNull[java.lang.Long](rows, ordinal, nulls, 0L) - TColumn.i64Val(new TI64Column(values, nulls)) - - case REAL => - val values = getOrSetAsNull[java.lang.Float](rows, ordinal, nulls, 0.toFloat) - .asScala.map(n => java.lang.Double.valueOf(n.toString)).asJava - TColumn.doubleVal(new TDoubleColumn(values, nulls)) - - case DOUBLE => - val values = getOrSetAsNull[java.lang.Double](rows, ordinal, nulls, 0.toDouble) - TColumn.doubleVal(new TDoubleColumn(values, nulls)) - - case VARCHAR => - val values = getOrSetAsNull[String](rows, ordinal, nulls, "") - TColumn.stringVal(new TStringColumn(values, nulls)) - - case VARBINARY => - val values = getOrSetAsNull[Array[Byte]](rows, ordinal, nulls, Array()) - .asScala - .map(ByteBuffer.wrap) - .asJava - TColumn.binaryVal(new TBinaryColumn(values, nulls)) - - case _ => - val rowSize = rows.length - val values = new util.ArrayList[String](rowSize) - var i = 0 - while (i < rowSize) { - val row = rows(i) - nulls.set(i, row(ordinal) == null) - val value = - if (row(ordinal) == null) { - "" - } else { - toHiveString(row(ordinal), typ) - } - values.add(value) - i += 1 - } - TColumn.stringVal(new TStringColumn(values, nulls)) - } - } - - private def getOrSetAsNull[T]( - rows: Seq[Seq[Any]], - ordinal: Int, - nulls: java.util.BitSet, - defaultVal: T): java.util.List[T] = { - val size = rows.length - val ret = new java.util.ArrayList[T](size) - var idx = 0 - while (idx < size) { - val row = rows(idx) - val isNull = row(ordinal) == null - if (isNull) { - nulls.set(idx, true) - ret.add(idx, defaultVal) - } else { - ret.add(idx, row(ordinal).asInstanceOf[T]) - } - idx += 1 - } - ret - } - - private def toTColumnValue( - ordinal: Int, - row: List[Any], - types: List[Column]): TColumnValue = { - - types(ordinal).getTypeSignature.getRawType match { - case BOOLEAN => - val boolValue = new TBoolValue - if (row(ordinal) != null) boolValue.setValue(row(ordinal).asInstanceOf[Boolean]) - TColumnValue.boolVal(boolValue) - - case TINYINT => - val byteValue = new TByteValue - if (row(ordinal) != null) byteValue.setValue(row(ordinal).asInstanceOf[Byte]) - 
TColumnValue.byteVal(byteValue) - - case SMALLINT => - val tI16Value = new TI16Value - if (row(ordinal) != null) tI16Value.setValue(row(ordinal).asInstanceOf[Short]) - TColumnValue.i16Val(tI16Value) - - case INTEGER => - val tI32Value = new TI32Value - if (row(ordinal) != null) tI32Value.setValue(row(ordinal).asInstanceOf[Int]) - TColumnValue.i32Val(tI32Value) - - case BIGINT => - val tI64Value = new TI64Value - if (row(ordinal) != null) tI64Value.setValue(row(ordinal).asInstanceOf[Long]) - TColumnValue.i64Val(tI64Value) - - case REAL => - val tDoubleValue = new TDoubleValue - if (row(ordinal) != null) { - val doubleValue = java.lang.Double.valueOf(row(ordinal).asInstanceOf[Float].toString) - tDoubleValue.setValue(doubleValue) - } - TColumnValue.doubleVal(tDoubleValue) - - case DOUBLE => - val tDoubleValue = new TDoubleValue - if (row(ordinal) != null) tDoubleValue.setValue(row(ordinal).asInstanceOf[Double]) - TColumnValue.doubleVal(tDoubleValue) - - case VARCHAR => - val tStringValue = new TStringValue - if (row(ordinal) != null) tStringValue.setValue(row(ordinal).asInstanceOf[String]) - TColumnValue.stringVal(tStringValue) - - case _ => - val tStrValue = new TStringValue - if (row(ordinal) != null) { - tStrValue.setValue( - toHiveString(row(ordinal), types(ordinal).getTypeSignature)) - } - TColumnValue.stringVal(tStrValue) - } - } - /** * A simpler impl of Trino's toHiveString */ diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/TrinoTRowSetGenerator.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/TrinoTRowSetGenerator.scala new file mode 100644 index 00000000000..9c323a5089b --- /dev/null +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/TrinoTRowSetGenerator.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/TrinoTRowSetGenerator.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/TrinoTRowSetGenerator.scala
new file mode 100644
index 00000000000..9c323a5089b
--- /dev/null
+++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/TrinoTRowSetGenerator.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.trino.schema
+
+import io.trino.client.{ClientTypeSignature, Column}
+import io.trino.client.ClientStandardTypes._
+
+import org.apache.kyuubi.engine.schema.AbstractTRowSetGenerator
+import org.apache.kyuubi.engine.trino.schema.RowSet.toHiveString
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TTypeId._
+import org.apache.kyuubi.util.RowSetUtils._
+
+class TrinoTRowSetGenerator
+  extends AbstractTRowSetGenerator[Seq[Column], Seq[_], ClientTypeSignature] {
+
+  override def getColumnSizeFromSchemaType(schema: Seq[Column]): Int = schema.length
+
+  override def getColumnType(schema: Seq[Column], ordinal: Int): ClientTypeSignature = {
+    schema(ordinal).getTypeSignature
+  }
+
+  override def isColumnNullAt(row: Seq[_], ordinal: Int): Boolean =
+    row(ordinal) == null
+
+  override def getColumnAs[T](row: Seq[_], ordinal: Int): T =
+    row(ordinal).asInstanceOf[T]
+
+  override def toTColumn(rows: Seq[Seq[_]], ordinal: Int, typ: ClientTypeSignature): TColumn = {
+    val nulls = new java.util.BitSet()
+    typ.getRawType match {
+      case BOOLEAN => toTTypeColumn(BOOLEAN_TYPE, rows, ordinal)
+      case TINYINT => toTTypeColumn(BINARY_TYPE, rows, ordinal)
+      case SMALLINT => toTTypeColumn(TINYINT_TYPE, rows, ordinal)
+      case INTEGER => toTTypeColumn(INT_TYPE, rows, ordinal)
+      case BIGINT => toTTypeColumn(BIGINT_TYPE, rows, ordinal)
+      case REAL => toTTypeColumn(FLOAT_TYPE, rows, ordinal)
+      case DOUBLE => toTTypeColumn(DOUBLE_TYPE, rows, ordinal)
+      case VARCHAR => toTTypeColumn(STRING_TYPE, rows, ordinal)
+      case VARBINARY => toTTypeColumn(ARRAY_TYPE, rows, ordinal)
+      case _ =>
+        val rowSize = rows.length
+        val values = new java.util.ArrayList[String](rowSize)
+        var i = 0
+        while (i < rowSize) {
+          val row = rows(i)
+          val isNull = isColumnNullAt(row, ordinal)
+          nulls.set(i, isNull)
+          val value = if (isNull) {
+            ""
+          } else {
+            toHiveString(row(ordinal), typ)
+          }
+          values.add(value)
+          i += 1
+        }
+        TColumn.stringVal(new TStringColumn(values, nulls))
+    }
+  }
+
+  override def toTColumnValue(ordinal: Int, row: Seq[_], types: Seq[Column]): TColumnValue = {
+    getColumnType(types, ordinal).getRawType match {
+      case BOOLEAN => toTTypeColumnVal(BOOLEAN_TYPE, row, ordinal)
+      case TINYINT => toTTypeColumnVal(BINARY_TYPE, row, ordinal)
+      case SMALLINT => toTTypeColumnVal(TINYINT_TYPE, row, ordinal)
+      case INTEGER => toTTypeColumnVal(INT_TYPE, row, ordinal)
+      case BIGINT => toTTypeColumnVal(BIGINT_TYPE, row, ordinal)
+      case REAL => toTTypeColumnVal(FLOAT_TYPE, row, ordinal)
+      case DOUBLE => toTTypeColumnVal(DOUBLE_TYPE, row, ordinal)
+      case VARCHAR => toTTypeColumnVal(STRING_TYPE, row, ordinal)
+      case _ =>
+        val tStrValue = new TStringValue
+        if (row(ordinal) != null) {
+          tStrValue.setValue(
+            toHiveString(row(ordinal), types(ordinal).getTypeSignature))
+        }
+        TColumnValue.stringVal(tStrValue)
+    }
+  }
+
+}
diff --git a/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/schema/RowSetSuite.scala b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/schema/RowSetSuite.scala
index acc55d5a3d1..461c453ecd2 100644
--- a/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/schema/RowSetSuite.scala
+++ b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/schema/RowSetSuite.scala
@@ -126,7 +126,7 @@ class RowSetSuite extends KyuubiFunSuite {
   def uuidSuffix(value: Int): String = if (value > 9) value.toString else s"f$value"
 
   test("column based set") {
-    val tRowSet = RowSet.toColumnBasedSet(rows, schema)
+    val tRowSet = new TrinoTRowSetGenerator().toColumnBasedSet(rows, schema)
     assert(tRowSet.getColumns.size() === schema.size)
     assert(tRowSet.getRowsSize === 0)
 
@@ -277,7 +277,7 @@ class RowSetSuite extends KyuubiFunSuite {
   }
 
   test("row based set") {
-    val tRowSet = RowSet.toRowBasedSet(rows, schema)
+    val tRowSet = new TrinoTRowSetGenerator().toRowBasedSet(rows, schema)
     assert(tRowSet.getColumnCount === 0)
     assert(tRowSet.getRowsSize === rows.size)
     val iter = tRowSet.getRowsIterator
@@ -333,7 +333,7 @@ class RowSetSuite extends KyuubiFunSuite {
 
   test("to row set") {
     TProtocolVersion.values().foreach { proto =>
-      val set = RowSet.toTRowSet(rows, schema, proto)
+      val set = new TrinoTRowSetGenerator().toTRowSet(rows, schema, proto)
       if (proto.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
         assert(!set.isSetColumns, proto.toString)
         assert(set.isSetRows, proto.toString)
diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/engine/schema/AbstractTRowSetGenerator.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/engine/schema/AbstractTRowSetGenerator.scala
new file mode 100644
index 00000000000..365ed7298b1
--- /dev/null
+++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/engine/schema/AbstractTRowSetGenerator.scala
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.schema
+import java.nio.ByteBuffer
+import java.util.{ArrayList => JArrayList, BitSet => JBitSet, List => JList}
+
+import scala.collection.JavaConverters._
+
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TTypeId._
+import org.apache.kyuubi.util.RowSetUtils.bitSetToBuffer
+
+trait AbstractTRowSetGenerator[SchemaT, RowT, ColumnT] {
+
+  protected def getColumnSizeFromSchemaType(schema: SchemaT): Int
+
+  protected def getColumnType(schema: SchemaT, ordinal: Int): ColumnT
+
+  protected def isColumnNullAt(row: RowT, ordinal: Int): Boolean
+
+  protected def getColumnAs[T](row: RowT, ordinal: Int): T
+
+  protected def toTColumn(rows: Seq[RowT], ordinal: Int, typ: ColumnT): TColumn
+
+  protected def toTColumnValue(ordinal: Int, row: RowT, types: SchemaT): TColumnValue
+
+  def toTRowSet(
+      rows: Seq[RowT],
+      schema: SchemaT,
+      protocolVersion: TProtocolVersion): TRowSet = {
+    if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
+      toRowBasedSet(rows, schema)
+    } else {
+      toColumnBasedSet(rows, schema)
+    }
+  }
+
+  def toRowBasedSet(rows: Seq[RowT], schema: SchemaT): TRowSet = {
+    val rowSize = rows.length
+    val tRows = new JArrayList[TRow](rowSize)
+    var i = 0
+    while (i < rowSize) {
+      val row = rows(i)
+      var j = 0
+      val columnSize = getColumnSizeFromSchemaType(schema)
+      val tColumnValues = new JArrayList[TColumnValue](columnSize)
+      while (j < columnSize) {
+        val columnValue = toTColumnValue(j, row, schema)
+        tColumnValues.add(columnValue)
+        j += 1
+      }
+      i += 1
+      val tRow = new TRow(tColumnValues)
+      tRows.add(tRow)
+    }
+    new TRowSet(0, tRows)
+  }
+
+  def toColumnBasedSet(rows: Seq[RowT], schema: SchemaT): TRowSet = {
+    val rowSize = rows.length
+    val tRowSet = new TRowSet(0, new JArrayList[TRow](rowSize))
+    var i = 0
+    val columnSize = getColumnSizeFromSchemaType(schema)
+    val tColumns = new JArrayList[TColumn](columnSize)
+    while (i < columnSize) {
+      val tColumn = toTColumn(rows, i, getColumnType(schema, i))
+      tColumns.add(tColumn)
+      i += 1
+    }
+    tRowSet.setColumns(tColumns)
+    tRowSet
+  }
+
+  protected def getOrSetAsNull[T](
+      rows: Seq[RowT],
+      ordinal: Int,
+      nulls: JBitSet,
+      defaultVal: T): JList[T] = {
+    val size = rows.length
+    val ret = new JArrayList[T](size)
+    var idx = 0
+    while (idx < size) {
+      val row = rows(idx)
+      val isNull = isColumnNullAt(row, ordinal)
+      if (isNull) {
+        nulls.set(idx, true)
+        ret.add(defaultVal)
+      } else {
+        ret.add(getColumnAs[T](row, ordinal))
+      }
+      idx += 1
+    }
+    ret
+  }
+
+  protected def toTTypeColumnVal(typeId: TTypeId, row: RowT, ordinal: Int): TColumnValue = {
+    def isNull = isColumnNullAt(row, ordinal)
+    typeId match {
+      case BOOLEAN_TYPE =>
+        val boolValue = new TBoolValue
+        if (!isNull) boolValue.setValue(getColumnAs[java.lang.Boolean](row, ordinal))
+        TColumnValue.boolVal(boolValue)
+
+      case BINARY_TYPE =>
+        val byteValue = new TByteValue
+        if (!isNull) byteValue.setValue(getColumnAs[java.lang.Byte](row, ordinal))
+        TColumnValue.byteVal(byteValue)
+
+      case TINYINT_TYPE =>
+        val tI16Value = new TI16Value
+        if (!isNull) tI16Value.setValue(getColumnAs[java.lang.Short](row, ordinal))
+        TColumnValue.i16Val(tI16Value)
+
+      case INT_TYPE =>
+        val tI32Value = new TI32Value
+        if (!isNull) tI32Value.setValue(getColumnAs[java.lang.Integer](row, ordinal))
+        TColumnValue.i32Val(tI32Value)
+
+      case BIGINT_TYPE =>
+        val tI64Value = new TI64Value
+        if (!isNull) tI64Value.setValue(getColumnAs[java.lang.Long](row, ordinal))
+        TColumnValue.i64Val(tI64Value)
+
+      case FLOAT_TYPE =>
+        val tDoubleValue = new TDoubleValue
+        if (!isNull) tDoubleValue.setValue(getColumnAs[java.lang.Float](row, ordinal).toDouble)
+        TColumnValue.doubleVal(tDoubleValue)
+
+      case DOUBLE_TYPE =>
+        val tDoubleValue = new TDoubleValue
+        if (!isNull) tDoubleValue.setValue(getColumnAs[java.lang.Double](row, ordinal))
+        TColumnValue.doubleVal(tDoubleValue)
+
+      case STRING_TYPE =>
+        val tStringValue = new TStringValue
+        if (!isNull) tStringValue.setValue(getColumnAs[String](row, ordinal))
+        TColumnValue.stringVal(tStringValue)
+
+      case otherType =>
+        throw new UnsupportedOperationException(s"unsupported type $otherType for toTTypeColumnVal")
+    }
+  }
+
+  protected def toTTypeColumn(typeId: TTypeId, rows: Seq[RowT], ordinal: Int): TColumn = {
+    val nulls = new JBitSet()
+    typeId match {
+      case BOOLEAN_TYPE =>
+        val values = getOrSetAsNull[java.lang.Boolean](rows, ordinal, nulls, true)
+        TColumn.boolVal(new TBoolColumn(values, nulls))
+
+      case BINARY_TYPE =>
+        val values = getOrSetAsNull[java.lang.Byte](rows, ordinal, nulls, 0.toByte)
+        TColumn.byteVal(new TByteColumn(values, nulls))
+
+      case SMALLINT_TYPE =>
+        val values = getOrSetAsNull[java.lang.Short](rows, ordinal, nulls, 0.toShort)
+        TColumn.i16Val(new TI16Column(values, nulls))
+
+      case TINYINT_TYPE =>
+        val values = getOrSetAsNull[java.lang.Short](rows, ordinal, nulls, 0.toShort)
+        TColumn.i16Val(new TI16Column(values, nulls))
+
+      case INT_TYPE =>
+        val values = getOrSetAsNull[java.lang.Integer](rows, ordinal, nulls, 0)
+        TColumn.i32Val(new TI32Column(values, nulls))
+
+      case BIGINT_TYPE =>
+        val values = getOrSetAsNull[java.lang.Long](rows, ordinal, nulls, 0L)
+        TColumn.i64Val(new TI64Column(values, nulls))
+
+      case FLOAT_TYPE =>
+        val values = getOrSetAsNull[java.lang.Float](rows, ordinal, nulls, 0.toFloat)
+          .asScala.map(n => java.lang.Double.valueOf(n.toString)).asJava
+        TColumn.doubleVal(new TDoubleColumn(values, nulls))
+
+      case DOUBLE_TYPE =>
+        val values = getOrSetAsNull[java.lang.Double](rows, ordinal, nulls, 0.toDouble)
+        TColumn.doubleVal(new TDoubleColumn(values, nulls))
+
+      case STRING_TYPE =>
+        val values = getOrSetAsNull[java.lang.String](rows, ordinal, nulls, "")
+        TColumn.stringVal(new TStringColumn(values, nulls))
+
+      case ARRAY_TYPE =>
+        val values = getOrSetAsNull[Array[Byte]](rows, ordinal, nulls, Array())
+          .asScala
+          .map(ByteBuffer.wrap)
+          .asJava
+        TColumn.binaryVal(new TBinaryColumn(values, nulls))
+
+      case otherType =>
+        throw new UnsupportedOperationException(s"unsupported type $otherType for toTTypeColumn")
+    }
+  }
+}
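Two conventions in this trait are easy to miss. First, the TTypeId keys passed to toTTypeColumn/toTTypeColumnVal mirror the legacy helpers rather than SQL semantics: BINARY_TYPE keys a byte column, both TINYINT_TYPE and SMALLINT_TYPE key an i16 column, and ARRAY_TYPE keys a binary column. Second, getOrSetAsNull keeps the value list and the null mask positionally aligned by padding null slots with a type default. A standalone sketch of that masking, with assumed values for illustration:

    // Sketch of the getOrSetAsNull convention: flag nulls in the BitSet,
    // pad the value list with a default so indices stay aligned with rows.
    import java.util.{ArrayList => JArrayList, BitSet => JBitSet}

    val cells: Seq[java.lang.Boolean] = Seq(true, null, false)
    val nulls = new JBitSet()
    val values = new JArrayList[java.lang.Boolean](cells.length)
    cells.zipWithIndex.foreach { case (v, i) =>
      if (v == null) {
        nulls.set(i, true)
        values.add(java.lang.Boolean.TRUE) // default pads the null slot
      } else {
        values.add(v)
      }
    }
    // values == [true, true, false]; nulls == {1} -- readers consult the mask first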
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/plan/command/RunnableCommand.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/plan/command/RunnableCommand.scala
index 8f19d7f7a24..cdfb515bd3a 100644
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/plan/command/RunnableCommand.scala
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/plan/command/RunnableCommand.scala
@@ -22,7 +22,7 @@ import org.apache.kyuubi.operation.FetchOrientation.{FETCH_FIRST, FETCH_NEXT, FETCH_PRIOR, FetchOrientation}
 import org.apache.kyuubi.session.KyuubiSession
 import org.apache.kyuubi.shaded.hive.service.rpc.thrift.{TProtocolVersion, TRowSet}
 import org.apache.kyuubi.sql.plan.KyuubiTreeNode
-import org.apache.kyuubi.sql.schema.{Row, RowSetHelper, Schema}
+import org.apache.kyuubi.sql.schema.{Row, Schema, ServerTRowSetGenerator}
 
 trait RunnableCommand extends KyuubiTreeNode {
 
@@ -44,7 +44,7 @@ trait RunnableCommand extends KyuubiTreeNode {
       case FETCH_FIRST => iter.fetchAbsolute(0)
     }
     val taken = iter.take(rowSetSize)
-    val resultRowSet = RowSetHelper.toTRowSet(
+    val resultRowSet = new ServerTRowSetGenerator().toTRowSet(
       taken.toList,
       resultSchema,
       protocolVersion)
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/schema/RowSetHelper.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/schema/RowSetHelper.scala
deleted file mode 100644
index 7a5fab0822e..00000000000
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/schema/RowSetHelper.scala
+++ /dev/null
@@ -1,209 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.kyuubi.sql.schema
-
-import java.util
-
-import scala.collection.JavaConverters._
-
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
-import org.apache.kyuubi.util.RowSetUtils._
-
-object RowSetHelper {
-
-  def toTRowSet(
-      rows: Seq[Row],
-      schema: Schema,
-      protocolVersion: TProtocolVersion): TRowSet = {
-    if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
-      toRowBasedSet(rows, schema)
-    } else {
-      toColumnBasedSet(rows, schema)
-    }
-  }
-
-  def toRowBasedSet(rows: Seq[Row], schema: Schema): TRowSet = {
-    var i = 0
-    val rowSize = rows.length
-    val tRows = new java.util.ArrayList[TRow](rowSize)
-    while (i < rowSize) {
-      val row = rows(i)
-      val tRow = new TRow()
-      var j = 0
-      val columnSize = row.length
-      while (j < columnSize) {
-        val columnValue = toTColumnValue(j, row, schema)
-        tRow.addToColVals(columnValue)
-        j += 1
-      }
-      i += 1
-      tRows.add(tRow)
-    }
-    new TRowSet(0, tRows)
-  }
-
-  private def toTColumnValue(
-      ordinal: Int,
-      row: Row,
-      types: Schema): TColumnValue = {
-    types(ordinal).dataType match {
-      case TTypeId.BOOLEAN_TYPE =>
-        val boolValue = new TBoolValue
-        if (!row.isNullAt(ordinal)) boolValue.setValue(row.getBoolean(ordinal))
-        TColumnValue.boolVal(boolValue)
-
-      case TTypeId.BINARY_TYPE =>
-        val byteValue = new TByteValue
-        if (!row.isNullAt(ordinal)) byteValue.setValue(row.getByte(ordinal))
-        TColumnValue.byteVal(byteValue)
-
-      case TTypeId.TINYINT_TYPE =>
-        val tI16Value = new TI16Value
-        if (!row.isNullAt(ordinal)) tI16Value.setValue(row.getShort(ordinal))
-        TColumnValue.i16Val(tI16Value)
-
-      case TTypeId.INT_TYPE =>
-        val tI32Value = new TI32Value
-        if (!row.isNullAt(ordinal)) tI32Value.setValue(row.getInt(ordinal))
-        TColumnValue.i32Val(tI32Value)
-
-      case TTypeId.BIGINT_TYPE =>
-        val tI64Value = new TI64Value
-        if (!row.isNullAt(ordinal)) tI64Value.setValue(row.getLong(ordinal))
-        TColumnValue.i64Val(tI64Value)
-
-      case TTypeId.FLOAT_TYPE =>
-        val tDoubleValue = new TDoubleValue
-        if (!row.isNullAt(ordinal)) {
-          val doubleValue = java.lang.Double.valueOf(row.getFloat(ordinal).toString)
-          tDoubleValue.setValue(doubleValue)
-        }
-        TColumnValue.doubleVal(tDoubleValue)
-
-      case TTypeId.DOUBLE_TYPE =>
-        val tDoubleValue = new TDoubleValue
-        if (!row.isNullAt(ordinal)) tDoubleValue.setValue(row.getDouble(ordinal))
-        TColumnValue.doubleVal(tDoubleValue)
-
-      case TTypeId.STRING_TYPE =>
-        val tStringValue = new TStringValue
-        if (!row.isNullAt(ordinal)) tStringValue.setValue(row.getString(ordinal))
-        TColumnValue.stringVal(tStringValue)
-
-      case _ =>
-        val tStrValue = new TStringValue
-        if (!row.isNullAt(ordinal)) {
-          tStrValue.setValue((row.get(ordinal), types(ordinal).dataType).toString())
-        }
-        TColumnValue.stringVal(tStrValue)
-    }
-  }
-
-  def toColumnBasedSet(rows: Seq[Row], schema: Schema): TRowSet = {
-    val rowSize = rows.length
-    val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](rowSize))
-    var i = 0
-    val columnSize = schema.length
-    while (i < columnSize) {
-      val field = schema(i)
-      val tColumn = toTColumn(rows, i, field.dataType)
-      tRowSet.addToColumns(tColumn)
-      i += 1
-    }
-    tRowSet
-  }
-
-  private def toTColumn(rows: Seq[Row], ordinal: Int, typ: TTypeId): TColumn = {
-    val nulls = new java.util.BitSet()
-    typ match {
-      case TTypeId.BOOLEAN_TYPE =>
-        val values = getOrSetAsNull[java.lang.Boolean](rows, ordinal, nulls, true)
-        TColumn.boolVal(new TBoolColumn(values, nulls))
-
-      case TTypeId.BINARY_TYPE =>
-        val values = getOrSetAsNull[java.lang.Byte](rows, ordinal, nulls, 0.toByte)
-        TColumn.byteVal(new TByteColumn(values, nulls))
-
-      case TTypeId.TINYINT_TYPE =>
-        val values = getOrSetAsNull[java.lang.Short](rows, ordinal, nulls, 0.toShort)
-        TColumn.i16Val(new TI16Column(values, nulls))
-
-      case TTypeId.INT_TYPE =>
-        val values = getOrSetAsNull[java.lang.Integer](rows, ordinal, nulls, 0)
-        TColumn.i32Val(new TI32Column(values, nulls))
-
-      case TTypeId.BIGINT_TYPE =>
-        val values = getOrSetAsNull[java.lang.Long](rows, ordinal, nulls, 0L)
-        TColumn.i64Val(new TI64Column(values, nulls))
-
-      case TTypeId.FLOAT_TYPE =>
-        val values = getOrSetAsNull[java.lang.Float](rows, ordinal, nulls, 0.toFloat)
-          .asScala.map(n => java.lang.Double.valueOf(n.toString)).asJava
-        TColumn.doubleVal(new TDoubleColumn(values, nulls))
-
-      case TTypeId.DOUBLE_TYPE =>
-        val values = getOrSetAsNull[java.lang.Double](rows, ordinal, nulls, 0.toDouble)
-        TColumn.doubleVal(new TDoubleColumn(values, nulls))
-
-      case TTypeId.STRING_TYPE =>
-        val values: util.List[String] = getOrSetAsNull[java.lang.String](rows, ordinal, nulls, "")
-        TColumn.stringVal(new TStringColumn(values, nulls))
-
-      case _ =>
-        var i = 0
-        val rowSize = rows.length
-        val values = new java.util.ArrayList[String](rowSize)
-        while (i < rowSize) {
-          val row = rows(i)
-          nulls.set(i, row.isNullAt(ordinal))
-          val value =
-            if (row.isNullAt(ordinal)) {
-              ""
-            } else {
-              (row.get(ordinal), typ).toString()
-            }
-          values.add(value)
-          i += 1
-        }
-        TColumn.stringVal(new TStringColumn(values, nulls))
-    }
-  }
-
-  private def getOrSetAsNull[T](
-      rows: Seq[Row],
-      ordinal: Int,
-      nulls: java.util.BitSet,
-      defaultVal: T): java.util.List[T] = {
-    val size = rows.length
-    val ret = new java.util.ArrayList[T](size)
-    var idx = 0
-    while (idx < size) {
-      val row = rows(idx)
-      val isNull = row.isNullAt(ordinal)
-      if (isNull) {
-        nulls.set(idx, true)
-        ret.add(idx, defaultVal)
-      } else {
-        ret.add(idx, row.getAs[T](ordinal))
-      }
-      idx += 1
-    }
-    ret
-  }
-
-}
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/schema/ServerTRowSetGenerator.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/schema/ServerTRowSetGenerator.scala
new file mode 100644
index 00000000000..e1a9d55a6e9
--- /dev/null
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/schema/ServerTRowSetGenerator.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql.schema
+
+import org.apache.kyuubi.engine.schema.AbstractTRowSetGenerator
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TTypeId._
+import org.apache.kyuubi.util.RowSetUtils._
+
+class ServerTRowSetGenerator
+  extends AbstractTRowSetGenerator[Schema, Row, TTypeId] {
+
+  override def getColumnSizeFromSchemaType(schema: Schema): Int = schema.length
+
+  override def getColumnType(schema: Schema, ordinal: Int): TTypeId = schema(ordinal).dataType
+
+  override def isColumnNullAt(row: Row, ordinal: Int): Boolean = row.isNullAt(ordinal)
+
+  override def getColumnAs[T](row: Row, ordinal: Int): T = row.getAs[T](ordinal)
+
+  override def toTColumn(rows: Seq[Row], ordinal: Int, typ: TTypeId): TColumn = {
+    val nulls = new java.util.BitSet()
+    typ match {
+      case t @ (BOOLEAN_TYPE | BINARY_TYPE | TINYINT_TYPE | INT_TYPE |
+          BIGINT_TYPE | FLOAT_TYPE | DOUBLE_TYPE | STRING_TYPE) =>
+        toTTypeColumn(t, rows, ordinal)
+
+      case _ =>
+        var i = 0
+        val rowSize = rows.length
+        val values = new java.util.ArrayList[String](rowSize)
+        while (i < rowSize) {
+          val row = rows(i)
+          val isNull = isColumnNullAt(row, ordinal)
+          nulls.set(i, isNull)
+          val value = if (isNull) {
+            ""
+          } else {
+            (row.get(ordinal), typ).toString()
+          }
+          values.add(value)
+          i += 1
+        }
+        TColumn.stringVal(new TStringColumn(values, nulls))
+    }
+  }
+
+  override def toTColumnValue(ordinal: Int, row: Row, types: Schema): TColumnValue = {
+    getColumnType(types, ordinal) match {
+      case t @ (BOOLEAN_TYPE | BINARY_TYPE | TINYINT_TYPE | INT_TYPE |
+          BIGINT_TYPE | FLOAT_TYPE | DOUBLE_TYPE | STRING_TYPE) =>
+        toTTypeColumnVal(t, row, ordinal)
+
+      case _ =>
+        val tStrValue = new TStringValue
+        if (!isColumnNullAt(row, ordinal)) {
+          tStrValue.setValue((row.get(ordinal), types(ordinal).dataType).toString())
+        }
+        TColumnValue.stringVal(tStrValue)
+    }
+  }
+
+}
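For completeness, a small verification sketch against the server-side generator, analogous to the Trino suite above. It is not part of the patch; `rows: Seq[Row]` and `schema: Schema` are assumed fixtures like those already used in kyuubi-server tests, and `getColumnsSize`/`getRowsSize` are the Thrift-generated list-size accessors on TRowSet:

    // Sketch: the column-based set carries one TColumn per schema field,
    // while the row-based set carries one TRow per input row.
    def verify(rows: Seq[Row], schema: Schema): Unit = {
      val gen = new ServerTRowSetGenerator()
      val columnar = gen.toColumnBasedSet(rows, schema)
      assert(columnar.getColumnsSize == schema.length)
      val rowBased = gen.toRowBasedSet(rows, schema)
      assert(rowBased.getRowsSize == rows.length)
    }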