From 412102ba5ba96430b83ec7a98403ce3ec916739c Mon Sep 17 00:00:00 2001 From: gnehil Date: Wed, 9 Aug 2023 10:23:27 +0800 Subject: [PATCH 01/10] init --- .../doris/spark/backend/BackendClient.java | 1 - .../org/apache/doris/spark/cfg/Settings.java | 3 +- .../doris/spark/http/RequestBuilder.scala | 5 ++ .../org/apache/doris/spark/package.scala | 6 +- .../spark/writer/DorisStreamLoader.scala | 20 +++++ .../doris/spark/writer/DorisWriterV2.scala | 86 +++++++++++++++++++ .../doris/spark/writer/RowSerializer.scala | 75 ++++++++++++++++ 7 files changed, 190 insertions(+), 6 deletions(-) create mode 100644 spark-doris-connector/src/main/scala/org/apache/doris/spark/http/RequestBuilder.scala create mode 100644 spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisStreamLoader.scala create mode 100644 spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriterV2.scala create mode 100644 spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/RowSerializer.scala diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java index aaafe096..e6de6dad 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java @@ -29,7 +29,6 @@ import org.apache.doris.spark.exception.ConnectedFailedException; import org.apache.doris.spark.exception.DorisException; import org.apache.doris.spark.exception.DorisInternalException; -import org.apache.doris.spark.util.ErrorMessages; import org.apache.doris.spark.cfg.Settings; import org.apache.doris.spark.serialization.Routing; import org.apache.thrift.TConfiguration; diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java index 798ec8cf..45a132df 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java @@ -23,8 +23,7 @@ import org.apache.commons.lang3.StringUtils; import org.apache.doris.spark.exception.IllegalArgumentException; -import org.apache.doris.spark.util.ErrorMessages; -import org.apache.doris.spark.util.IOUtils; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/http/RequestBuilder.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/http/RequestBuilder.scala new file mode 100644 index 00000000..4ce3dd01 --- /dev/null +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/http/RequestBuilder.scala @@ -0,0 +1,5 @@ +package org.apache.doris.spark.http + +class RequestBuilder { + +} diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala index d08bdc0d..9dee5158 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala @@ -18,18 +18,18 @@ package org.apache.doris import scala.language.implicitConversions - import org.apache.doris.spark.rdd.DorisSpark import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD package object spark { - implicit def sparkContextFunctions(sc: SparkContext) = new SparkContextFunctions(sc) + 
implicit def sparkContextFunctions(sc: SparkContext): SparkContextFunctions = new SparkContextFunctions(sc) class SparkContextFunctions(sc: SparkContext) extends Serializable { def dorisRDD( tableIdentifier: Option[String] = None, query: Option[String] = None, - cfg: Option[Map[String, String]] = None) = + cfg: Option[Map[String, String]] = None): RDD[AnyRef] = DorisSpark.dorisRDD(sc, tableIdentifier, query, cfg) } } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisStreamLoader.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisStreamLoader.scala new file mode 100644 index 00000000..ee07e3af --- /dev/null +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisStreamLoader.scala @@ -0,0 +1,20 @@ +package org.apache.doris.spark.writer + +import org.apache.doris.spark.cfg.SparkSettings + +class DorisStreamLoader(settings: SparkSettings) { + + def start(): Unit = Nil + + def load(rowData: Array[Byte]): Unit = { + + + } + + def stop(): Unit = Nil + + def commit(): Unit = Nil + + def abort(txnId: Long): Unit = Nil + +} diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriterV2.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriterV2.scala new file mode 100644 index 00000000..cc0e55c3 --- /dev/null +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriterV2.scala @@ -0,0 +1,86 @@ +package org.apache.doris.spark.writer + +import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings} +import org.apache.doris.spark.listener.DorisTransactionListener +import org.apache.doris.spark.sql.Utils +import org.apache.spark.sql.DataFrame +import org.slf4j.{Logger, LoggerFactory} + +import java.io.IOException +import java.time.Duration +import java.util +import java.util.Objects +import scala.collection.mutable +import scala.util.{Failure, Success, Try} + +class DorisWriterV2(settings: SparkSettings) { + + private val logger: Logger = LoggerFactory.getLogger(classOf[DorisWriterV2]) + + val batchSize: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE, + ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT) + + private val enable2PC: Boolean = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC, + ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT); + + private val dorisStreamLoader = new DorisStreamLoader(settings) + + def write(dataFrame: DataFrame): Unit = { + + val sc = dataFrame.sqlContext.sparkContext + val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc") + if (enable2PC) { + sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader)) + } + + + val rowSerializer = RowSerializer(settings, dataFrame.columns) + + dataFrame.rdd + .map(rowSerializer.serialize) + .foreachPartition(flush) + + /** + * flush data to Doris and do retry when flush error + * + */ + def flush(rowIter: Iterator[Array[Byte]]): Unit = { + + Try { + dorisStreamLoader.start() + rowIter.foreach(dorisStreamLoader.load) + } match { + case Success(txnIds) => { + dorisStreamLoader.stop() + if (enable2PC) txnIds.forEach(txnId => preCommittedTxnAcc.add(txnId)) + } + case Failure(e) => + } + + Try(dorisStreamLoader.load(rowData)) match { + case Success(txnIds) => if (enable2PC) txnIds.forEach(txnId => preCommittedTxnAcc.add(txnId)) + case Failure(e) => + if (enable2PC) { + // if task run failed, acc value will not be returned to driver, + // should abort all pre committed 
transactions inside the task + logger.info("load task failed, start aborting previously pre-committed transactions") + val abortFailedTxnIds = mutable.Buffer[Int]() + preCommittedTxnAcc.value.forEach(txnId => { + Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) { + dorisStreamLoader.abort(txnId) + } match { + case Success(_) => + case Failure(_) => abortFailedTxnIds += txnId + } + }) + if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(",")) + preCommittedTxnAcc.reset() + } + throw new IOException( + s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e) + } + } + + } + +} diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/RowSerializer.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/RowSerializer.scala new file mode 100644 index 00000000..68fd6ebc --- /dev/null +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/RowSerializer.scala @@ -0,0 +1,75 @@ +package org.apache.doris.spark.writer + +import com.fasterxml.jackson.databind.ObjectMapper +import org.apache.doris.spark.cfg.SparkSettings +import org.apache.spark.sql.Row + +import java.nio.charset.StandardCharsets +import java.sql.Timestamp + +class RowSerializer(format: String, + columns: Array[String], + columnSeparator: String + ) { + + private val mapper = new ObjectMapper(); + + def serialize(row: Row): Array[Byte] = { + if (columns.length != row.size) + return Array.empty[Byte] + format.toUpperCase match { + case "CSV" => toCsv(row) + case "JSON" => toJson(row) + case _ => throw new IllegalArgumentException("") + } + + } + + private def toCsv(row: Row): Array[Byte] = { + (0 to columns.length).map(i => row.get(i).toString).mkString(columnSeparator).getBytes(StandardCharsets.UTF_8) + } + + private def toJson(row: Row): Array[Byte] = { + mapper.writeValueAsBytes((0 to columns.length).map(i => { + val value = row.get(i) + value match { + case Timestamp => (columns(i), value.toString) + case _ => (columns(i), value) + } + }).toMap) + } + +} + +object RowSerializer { + + def apply(settings: SparkSettings, columns: Array[String]): RowSerializer = { + val format = settings.getProperty("format", "csv") + val columnSeparator = escapeString(settings.getProperty("column_separator", "\t")) + new RowSerializer(format, columns, columnSeparator) + } + + private def escapeString(hexData: String): String = { + if (hexData.startsWith("\\x") || hexData.startsWith("\\X")) { + try { + val data = hexData.substring(2) + val stringBuilder = new StringBuilder + var i = 0 + while (i < data.length) { + val hexByte = data.substring(i, i + 2) + val decimal = Integer.parseInt(hexByte, 16) + val character = decimal.toChar + stringBuilder.append(character) + + i += 2 + } + return stringBuilder.toString + } catch { + case e: Exception => + throw new RuntimeException("escape column_separator or line_delimiter error.{}", e) + } + } + hexData + } + +} From 999ec6a90be6b68838f58376d2c6e3272241786e Mon Sep 17 00:00:00 2001 From: gnehil Date: Mon, 14 Aug 2023 17:27:55 +0800 Subject: [PATCH 02/10] reduce mem use --- .../doris/spark/backend/BackendClient.java | 2 + .../org/apache/doris/spark/cfg/Settings.java | 2 + .../doris/spark/load/DorisStreamLoad.java | 31 ++------ .../apache/doris/spark/util/ListUtils.java | 28 ++++++-- .../doris/spark/writer/DorisWriter.scala | 6 +- .../doris/spark/writer/DorisWriterV2.scala | 72 +++++++++---------- 
.../doris/spark/writer/RowSerializer.scala | 7 +- 7 files changed, 75 insertions(+), 73 deletions(-) diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java index e6de6dad..b10797b4 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java @@ -31,6 +31,8 @@ import org.apache.doris.spark.exception.DorisInternalException; import org.apache.doris.spark.cfg.Settings; import org.apache.doris.spark.serialization.Routing; +import org.apache.doris.spark.util.ErrorMessages; + import org.apache.thrift.TConfiguration; import org.apache.thrift.TException; import org.apache.thrift.protocol.TBinaryProtocol; diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java index 45a132df..c941fdfa 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java @@ -23,6 +23,8 @@ import org.apache.commons.lang3.StringUtils; import org.apache.doris.spark.exception.IllegalArgumentException; +import org.apache.doris.spark.util.ErrorMessages; +import org.apache.doris.spark.util.IOUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java index ac920cd0..123447cf 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java @@ -44,8 +44,10 @@ import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClientBuilder; import org.apache.http.util.EntityUtils; +import org.apache.spark.sql.Row; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import scala.collection.Seq; import java.io.IOException; import java.io.Serializable; @@ -177,27 +179,8 @@ public String toString() { } } - public List loadV2(List> rows, String[] dfColumns, Boolean enable2PC) throws StreamLoadException, JsonProcessingException { - - List loadData = parseLoadData(rows, dfColumns); - List txnIds = new ArrayList<>(loadData.size()); - - try { - for (String data : loadData) { - txnIds.add(load(data, enable2PC)); - } - } catch (StreamLoadException e) { - if (enable2PC && !txnIds.isEmpty()) { - LOG.error("load batch failed, abort previously pre-committed transactions"); - for (Integer txnId : txnIds) { - abort(txnId); - } - } - throw e; - } - - return txnIds; - + public int loadV2(List rows, String[] dfColumns, Boolean enable2PC) throws StreamLoadException, JsonProcessingException { + return load(parseLoadData(rows, dfColumns), enable2PC); } public List loadStream(List> rows, String[] dfColumns, Boolean enable2PC) @@ -410,7 +393,7 @@ public List load(String key) throws Exception { } - private List parseLoadData(List> rows, String[] dfColumns) throws StreamLoadException, JsonProcessingException { + private String parseLoadData(List rows, String[] dfColumns) throws StreamLoadException, JsonProcessingException { List loadDataList; @@ -428,7 +411,7 @@ private List parseLoadData(List> rows, String[] dfColumns) case "JSON": List> 
dataList = new ArrayList<>(); try { - for (List row : rows) { + for (Row row : rows) { Map dataMap = new HashMap<>(); if (dfColumns.length == row.size()) { for (int i = 0; i < dfColumns.length; i++) { @@ -448,8 +431,6 @@ private List parseLoadData(List> rows, String[] dfColumns) } - return loadDataList; - } private String generateLoadLabel() { diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java index d8d31b9e..5affb6f5 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java @@ -21,10 +21,10 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.Lists; import org.apache.commons.lang3.exception.ExceptionUtils; +import org.apache.spark.sql.Row; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -33,11 +33,12 @@ public class ListUtils { private static final ObjectMapper MAPPER = new ObjectMapper(); - public static List getSerializedList(List> batch, - String lineDelimiter) throws JsonProcessingException { - List result = new ArrayList<>(); - divideAndSerialize(batch, result, lineDelimiter); - return result; + public static String getSerializedList(List> batch, + String lineDelimiter) throws JsonProcessingException { + // List result = new ArrayList<>(); + // divideAndSerialize(batch, result, lineDelimiter); + // return result; + return generateSerializedResult(batch, lineDelimiter); } /*** @@ -91,4 +92,19 @@ private static String generateSerializedResult(List> batch, } } + public static String mkString(Row row, String sep) { + StringBuilder builder = new StringBuilder(); + int n = row.size(); + if (n > 0) { + builder.append(row.get(0) == null ? "\\N" : row.get(0)); + int i = 1; + while (i < n) { + builder.append(sep); + builder.append(row.get(0) == null ? 
"\\N" : row.get(0)); + i++; + } + } + return builder.toString(); + } + } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala index e32267ee..36acddfd 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala @@ -21,6 +21,7 @@ import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings} import org.apache.doris.spark.listener.DorisTransactionListener import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad} import org.apache.doris.spark.sql.Utils +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.StructType @@ -68,7 +69,6 @@ class DorisWriter(settings: SparkSettings) extends Serializable { resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize) } resultRdd - .map(_.toSeq.map(_.asInstanceOf[AnyRef]).toList.asJava) .foreachPartition(partition => { partition .grouped(batchSize) @@ -79,8 +79,8 @@ class DorisWriter(settings: SparkSettings) extends Serializable { * flush data to Doris and do retry when flush error * */ - def flush(batch: Seq[util.List[Object]], dfColumns: Array[String]): Unit = { - Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) { + def flush(batch: Seq[Row], dfColumns: Array[String]): Unit = { + Utils.retry[Integer, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) { dorisStreamLoader.loadV2(batch.asJava, dfColumns, enable2PC) } match { case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc) diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriterV2.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriterV2.scala index cc0e55c3..9077b584 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriterV2.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriterV2.scala @@ -29,9 +29,9 @@ class DorisWriterV2(settings: SparkSettings) { val sc = dataFrame.sqlContext.sparkContext val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc") - if (enable2PC) { - sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader)) - } + // if (enable2PC) { + // sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader)) + // } val rowSerializer = RowSerializer(settings, dataFrame.columns) @@ -46,39 +46,39 @@ class DorisWriterV2(settings: SparkSettings) { */ def flush(rowIter: Iterator[Array[Byte]]): Unit = { - Try { - dorisStreamLoader.start() - rowIter.foreach(dorisStreamLoader.load) - } match { - case Success(txnIds) => { - dorisStreamLoader.stop() - if (enable2PC) txnIds.forEach(txnId => preCommittedTxnAcc.add(txnId)) - } - case Failure(e) => - } - - Try(dorisStreamLoader.load(rowData)) match { - case Success(txnIds) => if (enable2PC) txnIds.forEach(txnId => preCommittedTxnAcc.add(txnId)) - case Failure(e) => - if (enable2PC) { - // if task run failed, acc value will not be returned to driver, - // should abort all pre committed transactions inside the task - logger.info("load task failed, start 
aborting previously pre-committed transactions") - val abortFailedTxnIds = mutable.Buffer[Int]() - preCommittedTxnAcc.value.forEach(txnId => { - Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) { - dorisStreamLoader.abort(txnId) - } match { - case Success(_) => - case Failure(_) => abortFailedTxnIds += txnId - } - }) - if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(",")) - preCommittedTxnAcc.reset() - } - throw new IOException( - s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e) - } + // Try { + // dorisStreamLoader.start() + // rowIter.foreach(dorisStreamLoader.load) + // } match { + // case Success(txnIds) => { + // dorisStreamLoader.stop() + // if (enable2PC) txnIds.forEach(txnId => preCommittedTxnAcc.add(txnId)) + // } + // case Failure(e) => + // } + // + // Try(dorisStreamLoader.load(rowData)) match { + // case Success(txnIds) => if (enable2PC) txnIds.forEach(txnId => preCommittedTxnAcc.add(txnId)) + // case Failure(e) => + // if (enable2PC) { + // // if task run failed, acc value will not be returned to driver, + // // should abort all pre committed transactions inside the task + // logger.info("load task failed, start aborting previously pre-committed transactions") + // val abortFailedTxnIds = mutable.Buffer[Int]() + // preCommittedTxnAcc.value.forEach(txnId => { + // Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) { + // dorisStreamLoader.abort(txnId) + // } match { + // case Success(_) => + // case Failure(_) => abortFailedTxnIds += txnId + // } + // }) + // if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(",")) + // preCommittedTxnAcc.reset() + // } + // throw new IOException( + // s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e) + // } } } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/RowSerializer.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/RowSerializer.scala index 68fd6ebc..64806131 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/RowSerializer.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/RowSerializer.scala @@ -32,9 +32,10 @@ class RowSerializer(format: String, private def toJson(row: Row): Array[Byte] = { mapper.writeValueAsBytes((0 to columns.length).map(i => { val value = row.get(i) - value match { - case Timestamp => (columns(i), value.toString) - case _ => (columns(i), value) + if (value.isInstanceOf[Timestamp]) { + (columns(i), value.toString) + } else { + (columns(i), value) } }).toMap) } From 85c6032185523839dcce507ea025fefa3303ff6c Mon Sep 17 00:00:00 2001 From: gnehil Date: Mon, 21 Aug 2023 18:04:02 +0800 Subject: [PATCH 03/10] optimize --- .../doris/spark/load/DorisStreamLoad.java | 191 ++++++++---------- .../doris/spark/load/RowInputStream.java | 111 ++++++++++ .../doris/spark/serialization/RowBatch.java | 38 ++-- .../org/apache/doris/spark/util/DataUtil.java | 85 ++++++++ .../apache/doris/spark/util/ListUtils.java | 26 +-- .../doris/spark/http/RequestBuilder.scala | 5 - .../spark/writer/DorisStreamLoader.scala | 20 -- .../doris/spark/writer/DorisWriter.scala | 1 - .../doris/spark/writer/DorisWriterV2.scala | 86 -------- .../doris/spark/writer/RowSerializer.scala | 76 ------- 10 files changed, 301 insertions(+), 338 deletions(-) create mode 100644 
spark-doris-connector/src/main/java/org/apache/doris/spark/load/RowInputStream.java delete mode 100644 spark-doris-connector/src/main/scala/org/apache/doris/spark/http/RequestBuilder.scala delete mode 100644 spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisStreamLoader.scala delete mode 100644 spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriterV2.scala delete mode 100644 spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/RowSerializer.scala diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java index 123447cf..710a71e8 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java @@ -23,7 +23,6 @@ import org.apache.doris.spark.rest.models.BackendV2; import org.apache.doris.spark.rest.models.RespContent; import org.apache.doris.spark.util.DataUtil; -import org.apache.doris.spark.util.ListUtils; import org.apache.doris.spark.util.ResponseUtil; import com.fasterxml.jackson.core.JsonProcessingException; @@ -39,7 +38,9 @@ import org.apache.http.HttpStatus; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpPut; +import org.apache.http.client.methods.HttpRequestBase; import org.apache.http.entity.BufferedHttpEntity; +import org.apache.http.entity.InputStreamEntity; import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClientBuilder; @@ -47,12 +48,10 @@ import org.apache.spark.sql.Row; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import scala.collection.Seq; import java.io.IOException; import java.io.Serializable; import java.nio.charset.StandardCharsets; -import java.sql.Timestamp; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; @@ -65,37 +64,35 @@ import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; /** * DorisStreamLoad **/ public class DorisStreamLoad implements Serializable { - private String FIELD_DELIMITER; - private final String LINE_DELIMITER; - private static final String NULL_VALUE = "\\N"; private static final Logger LOG = LoggerFactory.getLogger(DorisStreamLoad.class); + private static final ObjectMapper MAPPER = new ObjectMapper(); + private final static List DORIS_SUCCESS_STATUS = new ArrayList<>(Arrays.asList("Success", "Publish Timeout")); - private static String loadUrlPattern = "http://%s/api/%s/%s/_stream_load?"; - private static String abortUrlPattern = "http://%s/api/%s/%s/_stream_load_2pc?"; + private static final String loadUrlPattern = "http://%s/api/%s/%s/_stream_load?"; + + private static final String abortUrlPattern = "http://%s/api/%s/%s/_stream_load_2pc?"; - private String user; - private String passwd; private String loadUrlStr; - private String db; - private String tbl; - private String authEncoded; - private String columns; - private String maxFilterRatio; - private Map streamLoadProp; + private final String db; + private final String tbl; + private final String authEncoded; + private final String columns; + private final String maxFilterRatio; + private final Map streamLoadProp; private static final long cacheExpireTimeout = 4 * 60; private final LoadingCache> cache; private 
final String fileType; - + private String FIELD_DELIMITER; + private final String LINE_DELIMITER; private boolean readJsonByLine = false; private boolean streamingPassthrough = false; @@ -104,8 +101,8 @@ public DorisStreamLoad(SparkSettings settings) { String[] dbTable = settings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER).split("\\."); this.db = dbTable[0]; this.tbl = dbTable[1]; - this.user = settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_USER); - this.passwd = settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD); + String user = settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_USER); + String passwd = settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD); this.authEncoded = getAuthEncoded(user, passwd); this.columns = settings.getProperty(ConfigurationOptions.DORIS_WRITE_FIELDS); this.maxFilterRatio = settings.getProperty(ConfigurationOptions.DORIS_MAX_FILTER_RATIO); @@ -143,9 +140,7 @@ private CloseableHttpClient getHttpClient() { private HttpPut getHttpPut(String label, String loadUrlStr, Boolean enable2PC) { HttpPut httpPut = new HttpPut(loadUrlStr); - httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded); - httpPut.setHeader(HttpHeaders.EXPECT, "100-continue"); - httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; charset=UTF-8"); + addCommonHeader(httpPut); httpPut.setHeader("label", label); if (StringUtils.isNotBlank(columns)) { httpPut.setHeader("columns", columns); @@ -167,6 +162,12 @@ public static class LoadResponse { public String respMsg; public String respContent; + public LoadResponse(HttpResponse response) throws IOException { + this.status = response.getStatusLine().getStatusCode(); + this.respMsg = response.getStatusLine().getReasonPhrase(); + this.respContent = EntityUtils.toString(new BufferedHttpEntity(response.getEntity()), StandardCharsets.UTF_8); + } + public LoadResponse(int status, String respMsg, String respContent) { this.status = status; this.respMsg = respMsg; @@ -180,7 +181,44 @@ public String toString() { } public int loadV2(List rows, String[] dfColumns, Boolean enable2PC) throws StreamLoadException, JsonProcessingException { - return load(parseLoadData(rows, dfColumns), enable2PC); + + String data = parseLoadData(rows, dfColumns); + + String label = generateLoadLabel(); + LoadResponse loadResponse; + try (CloseableHttpClient httpClient = getHttpClient()) { + String loadUrlStr = String.format(loadUrlPattern, getBackend(), db, tbl); + LOG.debug("Stream load Request:{} ,Body:{}", loadUrlStr, data); + // only to record the BE node in case of an exception + this.loadUrlStr = loadUrlStr; + HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC); + httpPut.setEntity(new InputStreamEntity(new RowInputStream(rows.iterator(), fileType, FIELD_DELIMITER, + LINE_DELIMITER, StandardCharsets.UTF_8))); + // httpPut.setEntity(new StringEntity(data, StandardCharsets.UTF_8)); + HttpResponse httpResponse = httpClient.execute(httpPut); + loadResponse = new LoadResponse(httpResponse); + } catch (IOException e) { + throw new RuntimeException(e); + } + + if (loadResponse.status != HttpStatus.SC_OK) { + LOG.info("Stream load Response HTTP Status Error:{}", loadResponse); + // throw new StreamLoadException("stream load error: " + loadResponse.respContent); + throw new StreamLoadException("stream load error"); + } else { + try { + RespContent respContent = MAPPER.readValue(loadResponse.respContent, RespContent.class); + if (!DORIS_SUCCESS_STATUS.contains(respContent.getStatus())) { + 
LOG.error("Stream load Response RES STATUS Error:{}", loadResponse); + throw new StreamLoadException("stream load error"); + } + LOG.info("Stream load Response:{}", loadResponse); + return respContent.getTxnId(); + } catch (IOException e) { + throw new StreamLoadException(e); + } + } + } public List loadStream(List> rows, String[] dfColumns, Boolean enable2PC) @@ -215,53 +253,6 @@ public List loadStream(List> rows, String[] dfColumns, Boo } - public int load(String value, Boolean enable2PC) throws StreamLoadException { - - String label = generateLoadLabel(); - - LoadResponse loadResponse; - int responseHttpStatus = -1; - try (CloseableHttpClient httpClient = getHttpClient()) { - String loadUrlStr = String.format(loadUrlPattern, getBackend(), db, tbl); - LOG.debug("Stream load Request:{} ,Body:{}", loadUrlStr, value); - // only to record the BE node in case of an exception - this.loadUrlStr = loadUrlStr; - - HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC); - httpPut.setEntity(new StringEntity(value, StandardCharsets.UTF_8)); - HttpResponse httpResponse = httpClient.execute(httpPut); - responseHttpStatus = httpResponse.getStatusLine().getStatusCode(); - String respMsg = httpResponse.getStatusLine().getReasonPhrase(); - String response = EntityUtils.toString(new BufferedHttpEntity(httpResponse.getEntity()), StandardCharsets.UTF_8); - loadResponse = new LoadResponse(responseHttpStatus, respMsg, response); - } catch (IOException e) { - e.printStackTrace(); - String err = "http request exception,load url : " + loadUrlStr + ",failed to execute spark stream load with label: " + label; - LOG.warn(err, e); - loadResponse = new LoadResponse(responseHttpStatus, e.getMessage(), err); - } - - if (loadResponse.status != HttpStatus.SC_OK) { - LOG.info("Stream load Response HTTP Status Error:{}", loadResponse); - // throw new StreamLoadException("stream load error: " + loadResponse.respContent); - throw new StreamLoadException("stream load error"); - } else { - ObjectMapper obj = new ObjectMapper(); - try { - RespContent respContent = obj.readValue(loadResponse.respContent, RespContent.class); - if (!DORIS_SUCCESS_STATUS.contains(respContent.getStatus())) { - LOG.error("Stream load Response RES STATUS Error:{}", loadResponse); - throw new StreamLoadException("stream load error"); - } - LOG.info("Stream load Response:{}", loadResponse); - return respContent.getTxnId(); - } catch (IOException e) { - throw new StreamLoadException(e); - } - } - - } - public void commit(int txnId) throws StreamLoadException { try (CloseableHttpClient client = getHttpClient()) { @@ -269,9 +260,7 @@ public void commit(int txnId) throws StreamLoadException { String backend = getBackend(); String abortUrl = String.format(abortUrlPattern, backend, db, tbl); HttpPut httpPut = new HttpPut(abortUrl); - httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded); - httpPut.setHeader(HttpHeaders.EXPECT, "100-continue"); - httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; charset=UTF-8"); + addCommonHeader(httpPut); httpPut.setHeader("txn_operation", "commit"); httpPut.setHeader("txn_id", String.valueOf(txnId)); @@ -289,10 +278,9 @@ public void commit(int txnId) throws StreamLoadException { throw new StreamLoadException("stream load error: " + reasonPhrase); } - ObjectMapper mapper = new ObjectMapper(); if (response.getEntity() != null) { String loadResult = EntityUtils.toString(response.getEntity()); - Map res = mapper.readValue(loadResult, new TypeReference>() { + Map res = MAPPER.readValue(loadResult, new 
TypeReference>() { }); if (res.get("status").equals("Fail") && !ResponseUtil.isCommitted(res.get("msg"))) { throw new StreamLoadException("Commit failed " + loadResult); @@ -314,9 +302,7 @@ public void abort(int txnId) throws StreamLoadException { try (CloseableHttpClient client = getHttpClient()) { String abortUrl = String.format(abortUrlPattern, getBackend(), db, tbl); HttpPut httpPut = new HttpPut(abortUrl); - httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded); - httpPut.setHeader(HttpHeaders.EXPECT, "100-continue"); - httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; charset=UTF-8"); + addCommonHeader(httpPut); httpPut.setHeader("txn_operation", "abort"); httpPut.setHeader("txn_id", String.valueOf(txnId)); @@ -327,9 +313,8 @@ public void abort(int txnId) throws StreamLoadException { throw new StreamLoadException("Fail to abort transaction " + txnId + " with url " + abortUrl); } - ObjectMapper mapper = new ObjectMapper(); String loadResult = EntityUtils.toString(response.getEntity()); - Map res = mapper.readValue(loadResult, new TypeReference>() { + Map res = MAPPER.readValue(loadResult, new TypeReference>() { }); if (!"Success".equals(res.get("status"))) { if (ResponseUtil.isCommitted(res.get("msg"))) { @@ -395,40 +380,17 @@ public List load(String key) throws Exception { private String parseLoadData(List rows, String[] dfColumns) throws StreamLoadException, JsonProcessingException { - List loadDataList; + if (dfColumns.length != rows.get(0).size()) { + return ""; + } switch (fileType.toUpperCase()) { - case "CSV": - loadDataList = Collections.singletonList( - rows.stream() - .map(row -> row.stream() - .map(DataUtil::handleColumnValue) - .map(Object::toString) - .collect(Collectors.joining(FIELD_DELIMITER)) - ).collect(Collectors.joining(LINE_DELIMITER))); - break; + return DataUtil.rowsToCsv(rows, FIELD_DELIMITER, LINE_DELIMITER); case "JSON": - List> dataList = new ArrayList<>(); - try { - for (Row row : rows) { - Map dataMap = new HashMap<>(); - if (dfColumns.length == row.size()) { - for (int i = 0; i < dfColumns.length; i++) { - dataMap.put(dfColumns[i], DataUtil.handleColumnValue(row.get(i))); - } - } - dataList.add(dataMap); - } - } catch (Exception e) { - throw new StreamLoadException("The number of configured columns does not match the number of data columns."); - } - // splits large collections to normal collection to avoid the "Requested array size exceeds VM limit" exception - loadDataList = ListUtils.getSerializedList(dataList, readJsonByLine ? LINE_DELIMITER : null); - break; + return DataUtil.rowsToJson(rows, dfColumns, readJsonByLine ? 
LINE_DELIMITER : null); default: throw new StreamLoadException(String.format("Unsupported file format in stream load: %s.", fileType)); - } } @@ -436,7 +398,10 @@ private String parseLoadData(List rows, String[] dfColumns) throws StreamLo private String generateLoadLabel() { Calendar calendar = Calendar.getInstance(); - return String.format("spark_streamload_%s%02d%02d_%02d%02d%02d_%s", calendar.get(Calendar.YEAR), calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH), calendar.get(Calendar.HOUR_OF_DAY), calendar.get(Calendar.MINUTE), calendar.get(Calendar.SECOND), UUID.randomUUID().toString().replaceAll("-", "")); + return String.format("spark_streamload_%s%02d%02d_%02d%02d%02d_%s", + calendar.get(Calendar.YEAR), calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH), + calendar.get(Calendar.HOUR_OF_DAY), calendar.get(Calendar.MINUTE), calendar.get(Calendar.SECOND), + UUID.randomUUID().toString().replaceAll("-", "")); } @@ -459,6 +424,12 @@ private String escapeString(String hexData) { return hexData; } + private void addCommonHeader(HttpRequestBase httpReq) { + httpReq.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded); + httpReq.setHeader(HttpHeaders.EXPECT, "100-continue"); + httpReq.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; charset=UTF-8"); + } + private void handleStreamPassThrough() { if ("json".equalsIgnoreCase(fileType)) { diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RowInputStream.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RowInputStream.java new file mode 100644 index 00000000..97f6b888 --- /dev/null +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RowInputStream.java @@ -0,0 +1,111 @@ +package org.apache.doris.spark.load; + +import org.apache.doris.spark.exception.DorisException; +import org.apache.doris.spark.exception.IllegalArgumentException; +import org.apache.doris.spark.util.DataUtil; + +import com.fasterxml.jackson.core.JsonProcessingException; +import org.apache.spark.sql.Row; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; + +public class RowInputStream extends InputStream { + + private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + + private final Iterator iterator; + + private final String format; + + private final String seq; + + private final String delim; + + private final String[] columns; + + private CharBuffer current; + + private ByteBuffer pending; + + public RowInputStream(Iterator iterator, String format, String seq, String delim, String[] columns) { + this.iterator = iterator; + this.format = format; + this.seq = seq; + this.delim = delim; + this.columns = columns; + } + + @Override + public int read() throws IOException { + for(;;) { + if(pending != null && pending.hasRemaining()) + return pending.get() & 0xff; + if(!ensureCurrent()) return -1; + if(pending == null) pending = ByteBuffer.allocate(4096); + else pending.compact(); + DEFAULT_CHARSET.encode(current); + pending.flip(); + } + } + + private boolean ensureCurrent() { + while(current == null || !current.hasRemaining()) { + if(!iterator.hasNext()) return false; + current = CharBuffer.wrap(iterator.next().toString()); + } + return true; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + int transferred = 0; + if(pending != null && 
pending.hasRemaining()) { + boolean serveByBuffer = pending.remaining() >= len; + pending.get(b, off, transferred = Math.min(pending.remaining(), len)); + if(serveByBuffer) return transferred; + len -= transferred; + off += transferred; + } + ByteBuffer bb = ByteBuffer.wrap(b, off, len); + while(bb.hasRemaining() && ensureCurrent()) { + int r = bb.remaining(); + try { + bb.put(rowToByte(iterator.next())); + } catch (DorisException e) { + throw new IOException(e); + } + transferred += r - bb.remaining(); + } + return transferred == 0? -1: transferred; + } + + private byte[] rowToByte(Row row) throws DorisException { + + byte[] bytes; + + switch (format.toLowerCase()) { + case "csv": + bytes = DataUtil.rowToCsvBytes(row, seq); + break; + case "json": + try { + bytes = DataUtil.rowToJsonBytes(row, columns); + } catch (JsonProcessingException e) { + throw new DorisException("parse row to json bytes failed", e); + } + break; + default: + throw new IllegalArgumentException("format", format); + } + + return bytes; + + } + +} diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java index faa8ef58..3d66db52 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java @@ -17,19 +17,11 @@ package org.apache.doris.spark.serialization; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.math.BigDecimal; -import java.math.BigInteger; -import java.nio.charset.StandardCharsets; -import java.sql.Date; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.time.format.DateTimeFormatter; -import java.util.ArrayList; -import java.util.List; -import java.util.NoSuchElementException; +import org.apache.doris.sdk.thrift.TScanBatchResult; +import org.apache.doris.spark.exception.DorisException; +import org.apache.doris.spark.rest.models.Schema; +import com.google.common.base.Preconditions; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.BitVector; @@ -47,17 +39,21 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.ipc.ArrowStreamReader; import org.apache.arrow.vector.types.Types; - -import org.apache.doris.sdk.thrift.TScanBatchResult; -import org.apache.doris.spark.exception.DorisException; -import org.apache.doris.spark.rest.models.Schema; - import org.apache.commons.lang3.ArrayUtils; import org.apache.spark.sql.types.Decimal; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.google.common.base.Preconditions; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.sql.Date; +import java.time.LocalDate; +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; /** * row batch data container. 
@@ -128,7 +124,11 @@ public RowBatch(TScanBatchResult nextResult, Schema schema) throws DorisExceptio } public boolean hasNext() { - return offsetInRowBatch < readRowCount; + if (offsetInRowBatch >= readRowCount) { + rowBatch.clear(); + return false; + } + return true; } private void addValueToRow(int rowIndex, Object obj) { diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java index 58774474..6a5ba23a 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java @@ -17,15 +17,28 @@ package org.apache.doris.spark.util; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.Row; import scala.collection.JavaConversions; import scala.collection.mutable.WrappedArray; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.sql.Date; import java.sql.Timestamp; import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; public class DataUtil { + private static final ObjectMapper MAPPER = new ObjectMapper(); + public static final String NULL_VALUE = "\\N"; public static Object handleColumnValue(Object value) { @@ -48,4 +61,76 @@ public static Object handleColumnValue(Object value) { } + public static String rowsToCsv(List rows, String sep, String lineDelimiter) { + StringBuilder builder = new StringBuilder(); + return rows.stream().map(row -> { + if (builder.length() != 0) { + builder.delete(0, builder.length()); + } + int n = row.size(); + if (n > 0) { + builder.append(handleColumnValue(row.get(0))); + int i = 1; + while (i < n) { + builder.append(sep); + builder.append(handleColumnValue(row.get(i))); + i++; + } + } + return builder.toString(); + }).collect(Collectors.joining(lineDelimiter)); + } + + public static String rowsToJson(List rows, String[] dfColumns, String lineDelimiter) + throws JsonProcessingException { + + List> batch = new LinkedList<>(); + for (Row row : rows) { + Map rowMap = new HashMap<>(row.size()); + for (int i = 0; i < dfColumns.length; i++) { + rowMap.put(dfColumns[i], handleColumnValue(row.get(i))); + } + batch.add(rowMap); + } + + // when lineDelimiter is null, use strip_outer_array mode, otherwise use json_by_line mode + if (lineDelimiter == null) { + return MAPPER.writeValueAsString(batch); + } else { + StringBuilder builder = new StringBuilder(); + for (Map data : batch) { + builder.append(MAPPER.writeValueAsString(data)).append(lineDelimiter); + } + int lastIdx = builder.lastIndexOf(lineDelimiter); + if (lastIdx != -1) { + return builder.substring(0, lastIdx); + } + return builder.toString(); + } + } + + public static byte[] rowToCsvBytes(Row row, String sep) { + StringBuilder builder = new StringBuilder(); + int n = row.size(); + if (n > 0) { + builder.append(handleColumnValue(row.get(0))); + int i = 1; + while (i < n) { + builder.append(sep); + builder.append(handleColumnValue(row.get(i))); + i++; + } + } + return builder.toString().getBytes(StandardCharsets.UTF_8); + } + + public static byte[] rowToJsonBytes(Row row, String[] dfColumns) + throws JsonProcessingException { + Map rowMap = new HashMap<>(row.size()); + for (int i = 0; i < dfColumns.length; i++) { + 
rowMap.put(dfColumns[i], handleColumnValue(row.get(i))); + } + return MAPPER.writeValueAsBytes(rowMap); + } + } diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java index 5affb6f5..fbfab9a5 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java @@ -21,10 +21,10 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.Lists; import org.apache.commons.lang3.exception.ExceptionUtils; -import org.apache.spark.sql.Row; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -33,12 +33,11 @@ public class ListUtils { private static final ObjectMapper MAPPER = new ObjectMapper(); - public static String getSerializedList(List> batch, + public static List getSerializedList(List> batch, String lineDelimiter) throws JsonProcessingException { - // List result = new ArrayList<>(); - // divideAndSerialize(batch, result, lineDelimiter); - // return result; - return generateSerializedResult(batch, lineDelimiter); + List result = new ArrayList<>(); + divideAndSerialize(batch, result, lineDelimiter); + return result; } /*** @@ -92,19 +91,4 @@ private static String generateSerializedResult(List> batch, } } - public static String mkString(Row row, String sep) { - StringBuilder builder = new StringBuilder(); - int n = row.size(); - if (n > 0) { - builder.append(row.get(0) == null ? "\\N" : row.get(0)); - int i = 1; - while (i < n) { - builder.append(sep); - builder.append(row.get(0) == null ? "\\N" : row.get(0)); - i++; - } - } - return builder.toString(); - } - } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/http/RequestBuilder.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/http/RequestBuilder.scala deleted file mode 100644 index 4ce3dd01..00000000 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/http/RequestBuilder.scala +++ /dev/null @@ -1,5 +0,0 @@ -package org.apache.doris.spark.http - -class RequestBuilder { - -} diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisStreamLoader.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisStreamLoader.scala deleted file mode 100644 index ee07e3af..00000000 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisStreamLoader.scala +++ /dev/null @@ -1,20 +0,0 @@ -package org.apache.doris.spark.writer - -import org.apache.doris.spark.cfg.SparkSettings - -class DorisStreamLoader(settings: SparkSettings) { - - def start(): Unit = Nil - - def load(rowData: Array[Byte]): Unit = { - - - } - - def stop(): Unit = Nil - - def commit(): Unit = Nil - - def abort(txnId: Long): Unit = Nil - -} diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala index 36acddfd..a4989114 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala @@ -30,7 +30,6 @@ import org.slf4j.{Logger, LoggerFactory} import java.io.IOException import java.time.Duration -import java.util import java.util.Objects import scala.collection.JavaConverters._ import 
scala.collection.mutable diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriterV2.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriterV2.scala deleted file mode 100644 index 9077b584..00000000 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriterV2.scala +++ /dev/null @@ -1,86 +0,0 @@ -package org.apache.doris.spark.writer - -import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings} -import org.apache.doris.spark.listener.DorisTransactionListener -import org.apache.doris.spark.sql.Utils -import org.apache.spark.sql.DataFrame -import org.slf4j.{Logger, LoggerFactory} - -import java.io.IOException -import java.time.Duration -import java.util -import java.util.Objects -import scala.collection.mutable -import scala.util.{Failure, Success, Try} - -class DorisWriterV2(settings: SparkSettings) { - - private val logger: Logger = LoggerFactory.getLogger(classOf[DorisWriterV2]) - - val batchSize: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE, - ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT) - - private val enable2PC: Boolean = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC, - ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT); - - private val dorisStreamLoader = new DorisStreamLoader(settings) - - def write(dataFrame: DataFrame): Unit = { - - val sc = dataFrame.sqlContext.sparkContext - val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc") - // if (enable2PC) { - // sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader)) - // } - - - val rowSerializer = RowSerializer(settings, dataFrame.columns) - - dataFrame.rdd - .map(rowSerializer.serialize) - .foreachPartition(flush) - - /** - * flush data to Doris and do retry when flush error - * - */ - def flush(rowIter: Iterator[Array[Byte]]): Unit = { - - // Try { - // dorisStreamLoader.start() - // rowIter.foreach(dorisStreamLoader.load) - // } match { - // case Success(txnIds) => { - // dorisStreamLoader.stop() - // if (enable2PC) txnIds.forEach(txnId => preCommittedTxnAcc.add(txnId)) - // } - // case Failure(e) => - // } - // - // Try(dorisStreamLoader.load(rowData)) match { - // case Success(txnIds) => if (enable2PC) txnIds.forEach(txnId => preCommittedTxnAcc.add(txnId)) - // case Failure(e) => - // if (enable2PC) { - // // if task run failed, acc value will not be returned to driver, - // // should abort all pre committed transactions inside the task - // logger.info("load task failed, start aborting previously pre-committed transactions") - // val abortFailedTxnIds = mutable.Buffer[Int]() - // preCommittedTxnAcc.value.forEach(txnId => { - // Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) { - // dorisStreamLoader.abort(txnId) - // } match { - // case Success(_) => - // case Failure(_) => abortFailedTxnIds += txnId - // } - // }) - // if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(",")) - // preCommittedTxnAcc.reset() - // } - // throw new IOException( - // s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e) - // } - } - - } - -} diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/RowSerializer.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/RowSerializer.scala deleted file mode 100644 index 64806131..00000000 --- 
a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/RowSerializer.scala +++ /dev/null @@ -1,76 +0,0 @@ -package org.apache.doris.spark.writer - -import com.fasterxml.jackson.databind.ObjectMapper -import org.apache.doris.spark.cfg.SparkSettings -import org.apache.spark.sql.Row - -import java.nio.charset.StandardCharsets -import java.sql.Timestamp - -class RowSerializer(format: String, - columns: Array[String], - columnSeparator: String - ) { - - private val mapper = new ObjectMapper(); - - def serialize(row: Row): Array[Byte] = { - if (columns.length != row.size) - return Array.empty[Byte] - format.toUpperCase match { - case "CSV" => toCsv(row) - case "JSON" => toJson(row) - case _ => throw new IllegalArgumentException("") - } - - } - - private def toCsv(row: Row): Array[Byte] = { - (0 to columns.length).map(i => row.get(i).toString).mkString(columnSeparator).getBytes(StandardCharsets.UTF_8) - } - - private def toJson(row: Row): Array[Byte] = { - mapper.writeValueAsBytes((0 to columns.length).map(i => { - val value = row.get(i) - if (value.isInstanceOf[Timestamp]) { - (columns(i), value.toString) - } else { - (columns(i), value) - } - }).toMap) - } - -} - -object RowSerializer { - - def apply(settings: SparkSettings, columns: Array[String]): RowSerializer = { - val format = settings.getProperty("format", "csv") - val columnSeparator = escapeString(settings.getProperty("column_separator", "\t")) - new RowSerializer(format, columns, columnSeparator) - } - - private def escapeString(hexData: String): String = { - if (hexData.startsWith("\\x") || hexData.startsWith("\\X")) { - try { - val data = hexData.substring(2) - val stringBuilder = new StringBuilder - var i = 0 - while (i < data.length) { - val hexByte = data.substring(i, i + 2) - val decimal = Integer.parseInt(hexByte, 16) - val character = decimal.toChar - stringBuilder.append(character) - - i += 2 - } - return stringBuilder.toString - } catch { - case e: Exception => - throw new RuntimeException("escape column_separator or line_delimiter error.{}", e) - } - } - hexData - } - -} From 0b7daedbd1d3bfac062018e476736bd562f3d39a Mon Sep 17 00:00:00 2001 From: gnehil Date: Mon, 28 Aug 2023 17:21:08 +0800 Subject: [PATCH 04/10] fix buffer read error --- .../doris/spark/load/DorisStreamLoad.java | 3 +- .../doris/spark/load/RowInputStream.java | 74 ++++++++++--------- 2 files changed, 39 insertions(+), 38 deletions(-) diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java index 710a71e8..99254a4d 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java @@ -41,7 +41,6 @@ import org.apache.http.client.methods.HttpRequestBase; import org.apache.http.entity.BufferedHttpEntity; import org.apache.http.entity.InputStreamEntity; -import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClientBuilder; import org.apache.http.util.EntityUtils; @@ -193,7 +192,7 @@ public int loadV2(List rows, String[] dfColumns, Boolean enable2PC) throws this.loadUrlStr = loadUrlStr; HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC); httpPut.setEntity(new InputStreamEntity(new RowInputStream(rows.iterator(), fileType, FIELD_DELIMITER, - LINE_DELIMITER, StandardCharsets.UTF_8))); + LINE_DELIMITER, dfColumns))); 
// httpPut.setEntity(new StringEntity(data, StandardCharsets.UTF_8)); HttpResponse httpResponse = httpClient.execute(httpPut); loadResponse = new LoadResponse(httpResponse); diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RowInputStream.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RowInputStream.java index 97f6b888..004013ec 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RowInputStream.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RowInputStream.java @@ -10,7 +10,6 @@ import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; -import java.nio.CharBuffer; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.Iterator; @@ -25,64 +24,67 @@ public class RowInputStream extends InputStream { private final String seq; - private final String delim; + private final byte[] delim; private final String[] columns; - private CharBuffer current; + private boolean isFirst = true; - private ByteBuffer pending; + private ByteBuffer buffer = ByteBuffer.allocate(0); public RowInputStream(Iterator iterator, String format, String seq, String delim, String[] columns) { this.iterator = iterator; this.format = format; this.seq = seq; - this.delim = delim; + this.delim = delim.getBytes(DEFAULT_CHARSET); this.columns = columns; } @Override public int read() throws IOException { - for(;;) { - if(pending != null && pending.hasRemaining()) - return pending.get() & 0xff; - if(!ensureCurrent()) return -1; - if(pending == null) pending = ByteBuffer.allocate(4096); - else pending.compact(); - DEFAULT_CHARSET.encode(current); - pending.flip(); + try { + if (buffer.remaining() == 0 && !readNext()) { + return -1; // End of stream + } + } catch (DorisException e) { + throw new IOException(e); } + return buffer.get() & 0xFF; } - private boolean ensureCurrent() { - while(current == null || !current.hasRemaining()) { - if(!iterator.hasNext()) return false; - current = CharBuffer.wrap(iterator.next().toString()); + @Override + public int read(byte[] b, int off, int len) throws IOException { + try { + if (buffer.remaining() == 0 && !readNext()) { + return -1; // End of stream + } + } catch (DorisException e) { + throw new IOException(e); } - return true; + int bytesRead = Math.min(len, buffer.remaining()); + buffer.get(b, off, bytesRead); + return bytesRead; } - @Override - public int read(byte[] b, int off, int len) throws IOException { - int transferred = 0; - if(pending != null && pending.hasRemaining()) { - boolean serveByBuffer = pending.remaining() >= len; - pending.get(b, off, transferred = Math.min(pending.remaining(), len)); - if(serveByBuffer) return transferred; - len -= transferred; - off += transferred; + public boolean readNext() throws DorisException { + if (!iterator.hasNext()) { + return false; } - ByteBuffer bb = ByteBuffer.wrap(b, off, len); - while(bb.hasRemaining() && ensureCurrent()) { - int r = bb.remaining(); - try { - bb.put(rowToByte(iterator.next())); - } catch (DorisException e) { - throw new IOException(e); + byte[] rowBytes = rowToByte(iterator.next()); + if (isFirst) { + buffer = ByteBuffer.wrap(rowBytes); + isFirst = false; + } else { + if (delim.length + rowBytes.length <= buffer.capacity()) { + buffer.clear(); + } else { + buffer = ByteBuffer.allocate(rowBytes.length + delim.length); } - transferred += r - bb.remaining(); + buffer.put(delim); + buffer.put(rowBytes); + buffer.flip(); } - return transferred == 0? 
-1: transferred; + return true; } private byte[] rowToByte(Row row) throws DorisException { From 28e29ad37b61678b1943a30b5666142bb340bd10 Mon Sep 17 00:00:00 2001 From: gnehil Date: Tue, 29 Aug 2023 14:44:14 +0800 Subject: [PATCH 05/10] optimize buffer expansion and add builder --- .../doris/spark/load/DorisStreamLoad.java | 24 ++--- .../doris/spark/load/RowInputStream.java | 102 ++++++++++++++++-- 2 files changed, 103 insertions(+), 23 deletions(-) diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java index 99254a4d..8488f734 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java @@ -111,14 +111,7 @@ public DorisStreamLoad(SparkSettings settings) { if ("csv".equals(fileType)) { FIELD_DELIMITER = escapeString(streamLoadProp.getOrDefault("column_separator", "\t")); } else if ("json".equalsIgnoreCase(fileType)) { - readJsonByLine = Boolean.parseBoolean(streamLoadProp.getOrDefault("read_json_by_line", "false")); - boolean stripOuterArray = Boolean.parseBoolean(streamLoadProp.getOrDefault("strip_outer_array", "false")); - if (readJsonByLine && stripOuterArray) { - throw new IllegalArgumentException("Only one of options 'read_json_by_line' and 'strip_outer_array' can be set to true"); - } else if (!readJsonByLine && !stripOuterArray) { - LOG.info("set default json mode: strip_outer_array"); - streamLoadProp.put("strip_outer_array", "true"); - } + streamLoadProp.put("read_json_by_line", "true"); } LINE_DELIMITER = escapeString(streamLoadProp.getOrDefault("line_delimiter", "\n")); this.streamingPassthrough = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH, @@ -151,7 +144,11 @@ private HttpPut getHttpPut(String label, String loadUrlStr, Boolean enable2PC) { httpPut.setHeader("two_phase_commit", "true"); } if (MapUtils.isNotEmpty(streamLoadProp)) { - streamLoadProp.forEach(httpPut::setHeader); + streamLoadProp.forEach((k, v) -> { + if (!"strip_outer_array".equalsIgnoreCase(k)) { + httpPut.setHeader(k, v); + } + }); } return httpPut; } @@ -191,9 +188,12 @@ public int loadV2(List rows, String[] dfColumns, Boolean enable2PC) throws // only to record the BE node in case of an exception this.loadUrlStr = loadUrlStr; HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC); - httpPut.setEntity(new InputStreamEntity(new RowInputStream(rows.iterator(), fileType, FIELD_DELIMITER, - LINE_DELIMITER, dfColumns))); - // httpPut.setEntity(new StringEntity(data, StandardCharsets.UTF_8)); + RowInputStream rowInputStream = RowInputStream.newBuilder(rows.iterator()) + .format(fileType) + .sep(FIELD_DELIMITER) + .delim(LINE_DELIMITER) + .columns(dfColumns).build(); + httpPut.setEntity(new InputStreamEntity(rowInputStream)); HttpResponse httpResponse = httpClient.execute(httpPut); loadResponse = new LoadResponse(httpResponse); } catch (IOException e) { diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RowInputStream.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RowInputStream.java index 004013ec..0063b02b 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RowInputStream.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RowInputStream.java @@ -6,6 +6,8 @@ import com.fasterxml.jackson.core.JsonProcessingException; import 
org.apache.spark.sql.Row; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.InputStream; @@ -16,13 +18,17 @@ public class RowInputStream extends InputStream { + public static final Logger LOG = LoggerFactory.getLogger(RowInputStream.class); + private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + private static final int DEFAULT_BUF_SIZE = 4096; + private final Iterator iterator; private final String format; - private final String seq; + private final String sep; private final byte[] delim; @@ -32,11 +38,11 @@ public class RowInputStream extends InputStream { private ByteBuffer buffer = ByteBuffer.allocate(0); - public RowInputStream(Iterator iterator, String format, String seq, String delim, String[] columns) { + private RowInputStream(Iterator iterator, String format, String sep, byte[] delim, String[] columns) { this.iterator = iterator; this.format = format; - this.seq = seq; - this.delim = delim.getBytes(DEFAULT_CHARSET); + this.sep = sep; + this.delim = delim; this.columns = columns; } @@ -72,14 +78,12 @@ public boolean readNext() throws DorisException { } byte[] rowBytes = rowToByte(iterator.next()); if (isFirst) { - buffer = ByteBuffer.wrap(rowBytes); + ensureCapacity(rowBytes.length); + buffer.put(rowBytes); + buffer.flip(); isFirst = false; } else { - if (delim.length + rowBytes.length <= buffer.capacity()) { - buffer.clear(); - } else { - buffer = ByteBuffer.allocate(rowBytes.length + delim.length); - } + ensureCapacity(delim.length + rowBytes.length); buffer.put(delim); buffer.put(rowBytes); buffer.flip(); @@ -87,13 +91,42 @@ public boolean readNext() throws DorisException { return true; } + private void ensureCapacity(int need) { + + int capacity = buffer.capacity(); + + if (need <= capacity) { + buffer.clear(); + return; + } + + // need to extend + int newCapacity = calculateNewCapacity(capacity, need); + LOG.info("expand buffer, min cap: {}, now cap: {}, new cap: {}", need, capacity, newCapacity); + buffer = ByteBuffer.allocate(newCapacity); + + } + + private int calculateNewCapacity(int capacity, int minCapacity) { + int newCapacity; + if (capacity == 0) { + newCapacity = DEFAULT_BUF_SIZE; + while (newCapacity < minCapacity) { + newCapacity = newCapacity << 1; + } + } else { + newCapacity = capacity << 1; + } + return newCapacity; + } + private byte[] rowToByte(Row row) throws DorisException { byte[] bytes; switch (format.toLowerCase()) { case "csv": - bytes = DataUtil.rowToCsvBytes(row, seq); + bytes = DataUtil.rowToCsvBytes(row, sep); break; case "json": try { @@ -110,4 +143,51 @@ private byte[] rowToByte(Row row) throws DorisException { } + public static Builder newBuilder(Iterator rows) { + return new Builder(rows); + } + + public static class Builder { + + private final Iterator rows; + + private String format; + + private String sep; + + private byte[] delim; + + private String[] columns; + + private Builder(Iterator rows) { + this.rows = rows; + } + + public Builder format(String format) { + this.format = format; + return this; + } + + public Builder sep(String sep) { + this.sep = sep; + return this; + } + + public Builder delim(String delim) { + this.delim = delim.getBytes(DEFAULT_CHARSET); + return this; + } + + public Builder columns(String[] columns) { + this.columns = columns; + return this; + } + + public RowInputStream build() { + return new RowInputStream(rows, format, sep, delim, columns); + } + + } + + } From b603a0546975dff034ecf216e12895eaca0590e1 Mon Sep 17 00:00:00 2001 From: gnehil 
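One more note on the buffer handling added above in RowInputStream: ensureCapacity reuses the existing ByteBuffer when the next record fits and otherwise grows it, with calculateNewCapacity doubling until the requested size is covered. A tiny sketch of that doubling rule (a hypothetical helper, simplified; the patch itself doubles once from the current capacity):

final class BufferGrowthSketch {

    private static final int DEFAULT_BUF_SIZE = 4096;

    // Double the capacity until the needed size fits, starting from a default size.
    static int newCapacity(int current, int needed) {
        int capacity = current == 0 ? DEFAULT_BUF_SIZE : current;
        while (capacity < needed) {
            capacity <<= 1;
        }
        return capacity;
    }

    public static void main(String[] args) {
        System.out.println(newCapacity(0, 10_000));   // 16384
        System.out.println(newCapacity(4096, 5_000)); // 8192
    }
}

Doubling keeps the number of re-allocations logarithmic in the largest row a task encounters, at the cost of at most a 2x overshoot in buffer size.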
Date: Wed, 30 Aug 2023 18:26:25 +0800 Subject: [PATCH 06/10] optimize --- .../doris/spark/load/DorisStreamLoad.java | 41 +---- .../apache/doris/spark/load/RecordBatch.java | 135 ++++++++++++++++ ...tream.java => RecordBatchInputStream.java} | 151 +++++++++--------- .../org/apache/doris/spark/util/DataUtil.java | 59 +------ .../doris/spark/sql/ScalaDorisRow.scala | 6 +- .../doris/spark/writer/DorisWriter.scala | 13 +- .../apache/doris/spark/util/DataUtilTest.java | 2 +- 7 files changed, 228 insertions(+), 179 deletions(-) create mode 100644 spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java rename spark-doris-connector/src/main/java/org/apache/doris/spark/load/{RowInputStream.java => RecordBatchInputStream.java} (55%) diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java index 8488f734..b00dd357 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java @@ -22,7 +22,6 @@ import org.apache.doris.spark.rest.RestService; import org.apache.doris.spark.rest.models.BackendV2; import org.apache.doris.spark.rest.models.RespContent; -import org.apache.doris.spark.util.DataUtil; import org.apache.doris.spark.util.ResponseUtil; import com.fasterxml.jackson.core.JsonProcessingException; @@ -57,6 +56,7 @@ import java.util.Calendar; import java.util.Collections; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Properties; @@ -92,7 +92,6 @@ public class DorisStreamLoad implements Serializable { private final String fileType; private String FIELD_DELIMITER; private final String LINE_DELIMITER; - private boolean readJsonByLine = false; private boolean streamingPassthrough = false; @@ -164,36 +163,28 @@ public LoadResponse(HttpResponse response) throws IOException { this.respContent = EntityUtils.toString(new BufferedHttpEntity(response.getEntity()), StandardCharsets.UTF_8); } - public LoadResponse(int status, String respMsg, String respContent) { - this.status = status; - this.respMsg = respMsg; - this.respContent = respContent; - } - @Override public String toString() { return "status: " + status + ", resp msg: " + respMsg + ", resp content: " + respContent; } } - public int loadV2(List rows, String[] dfColumns, Boolean enable2PC) throws StreamLoadException, JsonProcessingException { - - String data = parseLoadData(rows, dfColumns); + public int load(Iterator rows, String[] dfColumns, Boolean enable2PC, int batchSize) + throws StreamLoadException, JsonProcessingException { String label = generateLoadLabel(); LoadResponse loadResponse; try (CloseableHttpClient httpClient = getHttpClient()) { String loadUrlStr = String.format(loadUrlPattern, getBackend(), db, tbl); - LOG.debug("Stream load Request:{} ,Body:{}", loadUrlStr, data); - // only to record the BE node in case of an exception this.loadUrlStr = loadUrlStr; HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC); - RowInputStream rowInputStream = RowInputStream.newBuilder(rows.iterator()) + RecordBatchInputStream recodeBatchInputStream = new RecordBatchInputStream(RecordBatch.newBuilder(rows) + .batchSize(batchSize) .format(fileType) .sep(FIELD_DELIMITER) .delim(LINE_DELIMITER) - .columns(dfColumns).build(); - httpPut.setEntity(new InputStreamEntity(rowInputStream)); + 
.columns(dfColumns).build()); + httpPut.setEntity(new InputStreamEntity(recodeBatchInputStream)); HttpResponse httpResponse = httpClient.execute(httpPut); loadResponse = new LoadResponse(httpResponse); } catch (IOException e) { @@ -202,7 +193,6 @@ public int loadV2(List rows, String[] dfColumns, Boolean enable2PC) throws if (loadResponse.status != HttpStatus.SC_OK) { LOG.info("Stream load Response HTTP Status Error:{}", loadResponse); - // throw new StreamLoadException("stream load error: " + loadResponse.respContent); throw new StreamLoadException("stream load error"); } else { try { @@ -377,23 +367,6 @@ public List load(String key) throws Exception { } - private String parseLoadData(List rows, String[] dfColumns) throws StreamLoadException, JsonProcessingException { - - if (dfColumns.length != rows.get(0).size()) { - return ""; - } - - switch (fileType.toUpperCase()) { - case "CSV": - return DataUtil.rowsToCsv(rows, FIELD_DELIMITER, LINE_DELIMITER); - case "JSON": - return DataUtil.rowsToJson(rows, dfColumns, readJsonByLine ? LINE_DELIMITER : null); - default: - throw new StreamLoadException(String.format("Unsupported file format in stream load: %s.", fileType)); - } - - } - private String generateLoadLabel() { Calendar calendar = Calendar.getInstance(); diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java new file mode 100644 index 00000000..9f280e44 --- /dev/null +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java @@ -0,0 +1,135 @@ +package org.apache.doris.spark.load; + +import org.apache.spark.sql.Row; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; + +/** + * Wrapper Object for batch loading + */ +public class RecordBatch { + + private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + + /** + * Spark row data iterator + */ + private final Iterator iterator; + + /** + * batch size for single load + */ + private final int batchSize; + + /** + * stream load format + */ + private final String format; + + /** + * column separator, only used when the format is csv + */ + private final String sep; + + /** + * line delimiter + */ + private final byte[] delim; + + /** + * column name array, only used when the format is json + */ + private final String[] columns; + + private RecordBatch(Iterator iterator, int batchSize, String format, String sep, byte[] delim, String[] columns) { + this.iterator = iterator; + this.batchSize = batchSize; + this.format = format; + this.sep = sep; + this.delim = delim; + this.columns = columns; + } + + public Iterator getIterator() { + return iterator; + } + + public int getBatchSize() { + return batchSize; + } + + public String getFormat() { + return format; + } + + public String getSep() { + return sep; + } + + public byte[] getDelim() { + return delim; + } + + public String[] getColumns() { + return columns; + } + + public static Builder newBuilder(Iterator iterator) { + return new Builder(iterator); + } + + /** + * RecordBatch Builder + */ + public static class Builder { + + private final Iterator iterator; + + private int batchSize; + + private String format; + + private String sep; + + private byte[] delim; + + private String[] columns; + + public Builder(Iterator iterator) { + this.iterator = iterator; + } + + public Builder batchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } + + public Builder 
format(String format) { + this.format = format; + return this; + } + + public Builder sep(String sep) { + this.sep = sep; + return this; + } + + public Builder delim(String delim) { + this.delim = delim.getBytes(DEFAULT_CHARSET); + return this; + } + + public Builder columns(String[] columns) { + this.columns = columns; + return this; + } + + public RecordBatch build() { + return new RecordBatch(iterator, batchSize, format, sep, delim, columns); + } + + } + +} diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RowInputStream.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java similarity index 55% rename from spark-doris-connector/src/main/java/org/apache/doris/spark/load/RowInputStream.java rename to spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java index 0063b02b..ab1251a1 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RowInputStream.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java @@ -2,6 +2,7 @@ import org.apache.doris.spark.exception.DorisException; import org.apache.doris.spark.exception.IllegalArgumentException; +import org.apache.doris.spark.exception.ShouldNeverHappenException; import org.apache.doris.spark.util.DataUtil; import com.fasterxml.jackson.core.JsonProcessingException; @@ -12,44 +13,45 @@ import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; -import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; import java.util.Iterator; -public class RowInputStream extends InputStream { +/** + * InputStream for batch load + */ +public class RecordBatchInputStream extends InputStream { - public static final Logger LOG = LoggerFactory.getLogger(RowInputStream.class); - - private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; + public static final Logger LOG = LoggerFactory.getLogger(RecordBatchInputStream.class); private static final int DEFAULT_BUF_SIZE = 4096; - private final Iterator iterator; - - private final String format; - - private final String sep; - - private final byte[] delim; - - private final String[] columns; + /** + * Load record batch + */ + private final RecordBatch recordBatch; + /** + * first line flag + */ private boolean isFirst = true; + /** + * record buffer + */ private ByteBuffer buffer = ByteBuffer.allocate(0); - private RowInputStream(Iterator iterator, String format, String sep, byte[] delim, String[] columns) { - this.iterator = iterator; - this.format = format; - this.sep = sep; - this.delim = delim; - this.columns = columns; + /** + * record count has been read + */ + private int readCount = 0; + + public RecordBatchInputStream(RecordBatch recordBatch) { + this.recordBatch = recordBatch; } @Override public int read() throws IOException { try { - if (buffer.remaining() == 0 && !readNext()) { + if (buffer.remaining() == 0 && endOfBatch()) { return -1; // End of stream } } catch (DorisException e) { @@ -61,7 +63,7 @@ public int read() throws IOException { @Override public int read(byte[] b, int off, int len) throws IOException { try { - if (buffer.remaining() == 0 && !readNext()) { + if (buffer.remaining() == 0 && endOfBatch()) { return -1; // End of stream } } catch (DorisException e) { @@ -72,10 +74,34 @@ public int read(byte[] b, int off, int len) throws IOException { return bytesRead; } - public boolean readNext() throws DorisException { + /** + * Check if the current batch read is over. 
+ * If the number of reads is greater than or equal to the batch size or there is no next record, return false, + * otherwise return true. + * + * @return Whether the current batch read is over + * @throws DorisException + */ + public boolean endOfBatch() throws DorisException { + Iterator iterator = recordBatch.getIterator(); + if (readCount >= recordBatch.getBatchSize() || !iterator.hasNext()) { + return true; + } + readNext(iterator); + return false; + } + + /** + * read next record into buffer + * + * @param iterator row iterator + * @throws DorisException + */ + private void readNext(Iterator iterator) throws DorisException { if (!iterator.hasNext()) { - return false; + throw new ShouldNeverHappenException(); } + byte[] delim = recordBatch.getDelim(); byte[] rowBytes = rowToByte(iterator.next()); if (isFirst) { ensureCapacity(rowBytes.length); @@ -88,9 +114,14 @@ public boolean readNext() throws DorisException { buffer.put(rowBytes); buffer.flip(); } - return true; + readCount++; } + /** + * Check if the buffer has enough capacity. + * + * @param need required buffer space + */ private void ensureCapacity(int need) { int capacity = buffer.capacity(); @@ -107,6 +138,13 @@ private void ensureCapacity(int need) { } + /** + * Calculate new capacity for buffer expansion. + * + * @param capacity current buffer capacity + * @param minCapacity required min buffer space + * @return new capacity + */ private int calculateNewCapacity(int capacity, int minCapacity) { int newCapacity; if (capacity == 0) { @@ -120,74 +158,35 @@ private int calculateNewCapacity(int capacity, int minCapacity) { return newCapacity; } + /** + * Convert Spark row data to byte array + * + * @param row row data + * @return byte array + * @throws DorisException + */ private byte[] rowToByte(Row row) throws DorisException { byte[] bytes; - switch (format.toLowerCase()) { + switch (recordBatch.getFormat().toLowerCase()) { case "csv": - bytes = DataUtil.rowToCsvBytes(row, sep); + bytes = DataUtil.rowToCsvBytes(row, recordBatch.getSep()); break; case "json": try { - bytes = DataUtil.rowToJsonBytes(row, columns); + bytes = DataUtil.rowToJsonBytes(row, recordBatch.getColumns()); } catch (JsonProcessingException e) { throw new DorisException("parse row to json bytes failed", e); } break; default: - throw new IllegalArgumentException("format", format); + throw new IllegalArgumentException("format", recordBatch.getFormat()); } return bytes; } - public static Builder newBuilder(Iterator rows) { - return new Builder(rows); - } - - public static class Builder { - - private final Iterator rows; - - private String format; - - private String sep; - - private byte[] delim; - - private String[] columns; - - private Builder(Iterator rows) { - this.rows = rows; - } - - public Builder format(String format) { - this.format = format; - return this; - } - - public Builder sep(String sep) { - this.sep = sep; - return this; - } - - public Builder delim(String delim) { - this.delim = delim.getBytes(DEFAULT_CHARSET); - return this; - } - - public Builder columns(String[] columns) { - this.columns = columns; - return this; - } - - public RowInputStream build() { - return new RowInputStream(rows, format, sep, delim, columns); - } - - } - } diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java index 6a5ba23a..65d5a15e 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java +++ 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java @@ -19,21 +19,14 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.collect.ImmutableMap; import org.apache.spark.sql.Row; -import scala.collection.JavaConversions; import scala.collection.mutable.WrappedArray; -import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.sql.Date; import java.sql.Timestamp; -import java.util.Arrays; import java.util.HashMap; -import java.util.LinkedList; -import java.util.List; import java.util.Map; -import java.util.stream.Collectors; public class DataUtil { @@ -52,63 +45,13 @@ public static Object handleColumnValue(Object value) { } if (value instanceof WrappedArray) { - - Object[] arr = JavaConversions.seqAsJavaList((WrappedArray) value).toArray(); - return Arrays.toString(arr); + return String.format("[%s]", ((WrappedArray) value).mkString(",")); } return value; } - public static String rowsToCsv(List rows, String sep, String lineDelimiter) { - StringBuilder builder = new StringBuilder(); - return rows.stream().map(row -> { - if (builder.length() != 0) { - builder.delete(0, builder.length()); - } - int n = row.size(); - if (n > 0) { - builder.append(handleColumnValue(row.get(0))); - int i = 1; - while (i < n) { - builder.append(sep); - builder.append(handleColumnValue(row.get(i))); - i++; - } - } - return builder.toString(); - }).collect(Collectors.joining(lineDelimiter)); - } - - public static String rowsToJson(List rows, String[] dfColumns, String lineDelimiter) - throws JsonProcessingException { - - List> batch = new LinkedList<>(); - for (Row row : rows) { - Map rowMap = new HashMap<>(row.size()); - for (int i = 0; i < dfColumns.length; i++) { - rowMap.put(dfColumns[i], handleColumnValue(row.get(i))); - } - batch.add(rowMap); - } - - // when lineDelimiter is null, use strip_outer_array mode, otherwise use json_by_line mode - if (lineDelimiter == null) { - return MAPPER.writeValueAsString(batch); - } else { - StringBuilder builder = new StringBuilder(); - for (Map data : batch) { - builder.append(MAPPER.writeValueAsString(data)).append(lineDelimiter); - } - int lastIdx = builder.lastIndexOf(lineDelimiter); - if (lastIdx != -1) { - return builder.substring(0, lastIdx); - } - return builder.toString(); - } - } - public static byte[] rowToCsvBytes(Row row, String sep) { StringBuilder builder = new StringBuilder(); int n = row.size(); diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala index 06f5ca30..ec8f887a 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala @@ -27,7 +27,7 @@ private[spark] class ScalaDorisRow(rowOrder: Seq[String]) extends Row { /** No-arg constructor for Kryo serialization. 
*/ def this() = this(null) - def iterator = values.iterator + def iterator: Iterator[Any] = values.iterator override def length: Int = values.length @@ -51,9 +51,9 @@ private[spark] class ScalaDorisRow(rowOrder: Seq[String]) extends Row { override def getByte(i: Int): Byte = getAs[Byte](i) - override def getString(i: Int): String = get(i).toString() + override def getString(i: Int): String = get(i).toString override def copy(): Row = this - override def toSeq = values.toSeq + override def toSeq: Seq[Any] = values } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala index a4989114..10699bb2 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala @@ -68,19 +68,18 @@ class DorisWriter(settings: SparkSettings) extends Serializable { resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize) } resultRdd - .foreachPartition(partition => { - partition - .grouped(batchSize) - .foreach(batch => flush(batch, dfColumns)) - }) + .foreachPartition(partition => + while (partition.hasNext) + loadBatch(partition, dfColumns, batchSize) + ) /** * flush data to Doris and do retry when flush error * */ - def flush(batch: Seq[Row], dfColumns: Array[String]): Unit = { + def loadBatch(batch: Iterator[Row], dfColumns: Array[String], batchSize: Int): Unit = { Utils.retry[Integer, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) { - dorisStreamLoader.loadV2(batch.asJava, dfColumns, enable2PC) + dorisStreamLoader.load(batch.asJava, dfColumns, enable2PC, batchSize) } match { case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc) case Failure(e) => diff --git a/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java b/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java index 020a241c..0f6fb36b 100644 --- a/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java +++ b/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java @@ -27,6 +27,6 @@ public class DataUtilTest extends TestCase { public void testHandleColumnValue() { Assert.assertEquals("2023-08-14 18:00:00.0", DataUtil.handleColumnValue(Timestamp.valueOf("2023-08-14 18:00:00"))); - Assert.assertEquals("[1, 2, 3]", DataUtil.handleColumnValue(WrappedArray.make(new Integer[]{1,2,3}))); + Assert.assertEquals("[1,2,3]", DataUtil.handleColumnValue(WrappedArray.make(new Integer[]{1,2,3}))); } } \ No newline at end of file From 63ab2565e3637407bd3fd555d3553a7a330713ec Mon Sep 17 00:00:00 2001 From: gnehil Date: Thu, 7 Sep 2023 15:41:36 +0800 Subject: [PATCH 07/10] merge master and do some refract --- .../doris/spark/load/DorisStreamLoad.java | 47 +++------- .../apache/doris/spark/load/RecordBatch.java | 35 +++---- .../spark/load/RecordBatchInputStream.java | 34 +++++-- .../org/apache/doris/spark/util/DataUtil.java | 6 +- .../doris/spark/writer/DorisWriter.scala | 94 +++++-------------- 5 files changed, 88 insertions(+), 128 deletions(-) diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java index b00dd357..d9b1d516 100644 --- 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java @@ -35,6 +35,7 @@ import org.apache.http.HttpHeaders; import org.apache.http.HttpResponse; import org.apache.http.HttpStatus; +import org.apache.http.client.config.RequestConfig; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpPut; import org.apache.http.client.methods.HttpRequestBase; @@ -44,6 +45,9 @@ import org.apache.http.impl.client.HttpClientBuilder; import org.apache.http.util.EntityUtils; import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer; +import org.apache.spark.sql.types.StructType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -92,8 +96,9 @@ public class DorisStreamLoad implements Serializable { private final String fileType; private String FIELD_DELIMITER; private final String LINE_DELIMITER; - private boolean streamingPassthrough = false; + private final Integer batchSize; + private boolean enable2PC; public DorisStreamLoad(SparkSettings settings) { String[] dbTable = settings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER).split("\\."); @@ -115,6 +120,10 @@ public DorisStreamLoad(SparkSettings settings) { LINE_DELIMITER = escapeString(streamLoadProp.getOrDefault("line_delimiter", "\n")); this.streamingPassthrough = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH, ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT); + this.batchSize = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE, + ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT); + this.enable2PC = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC, + ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT); } public String getLoadUrlStr() { @@ -169,7 +178,7 @@ public String toString() { } } - public int load(Iterator rows, String[] dfColumns, Boolean enable2PC, int batchSize) + public int load(Iterator rows, StructType schema, Deserializer deserializer) throws StreamLoadException, JsonProcessingException { String label = generateLoadLabel(); @@ -183,7 +192,7 @@ public int load(Iterator rows, String[] dfColumns, Boolean enable2PC, int b .format(fileType) .sep(FIELD_DELIMITER) .delim(LINE_DELIMITER) - .columns(dfColumns).build()); + .schema(schema).build(), deserializer, streamingPassthrough); httpPut.setEntity(new InputStreamEntity(recodeBatchInputStream)); HttpResponse httpResponse = httpClient.execute(httpPut); loadResponse = new LoadResponse(httpResponse); @@ -210,36 +219,12 @@ public int load(Iterator rows, String[] dfColumns, Boolean enable2PC, int b } - public List loadStream(List> rows, String[] dfColumns, Boolean enable2PC) + public Integer loadStream(Iterator rows, StructType schema, Deserializer deserializer) throws StreamLoadException, JsonProcessingException { - - List loadData; - if (this.streamingPassthrough) { handleStreamPassThrough(); - loadData = passthrough(rows); - } else { - loadData = parseLoadData(rows, dfColumns); } - - List txnIds = new ArrayList<>(loadData.size()); - - try { - for (String data : loadData) { - txnIds.add(load(data, enable2PC)); - } - } catch (StreamLoadException e) { - if (enable2PC && !txnIds.isEmpty()) { - LOG.error("load batch failed, abort previously pre-committed transactions"); - for (Integer txnId : txnIds) { - abort(txnId); - } - } - 
throw e; - } - - return txnIds; - + return load(rows, schema, deserializer); } public void commit(int txnId) throws StreamLoadException { @@ -412,8 +397,4 @@ private void handleStreamPassThrough() { } - private List passthrough(List> values) { - return values.stream().map(list -> list.get(0).toString()).collect(Collectors.toList()); - } - } diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java index 9f280e44..caeb4c9a 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java @@ -1,6 +1,7 @@ package org.apache.doris.spark.load; -import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; @@ -16,7 +17,7 @@ public class RecordBatch { /** * Spark row data iterator */ - private final Iterator iterator; + private final Iterator iterator; /** * batch size for single load @@ -39,20 +40,21 @@ public class RecordBatch { private final byte[] delim; /** - * column name array, only used when the format is json + * schema of row */ - private final String[] columns; + private final StructType schema; - private RecordBatch(Iterator iterator, int batchSize, String format, String sep, byte[] delim, String[] columns) { + private RecordBatch(Iterator iterator, int batchSize, String format, String sep, byte[] delim, + StructType schema) { this.iterator = iterator; this.batchSize = batchSize; this.format = format; this.sep = sep; this.delim = delim; - this.columns = columns; + this.schema = schema; } - public Iterator getIterator() { + public Iterator getIterator() { return iterator; } @@ -72,11 +74,10 @@ public byte[] getDelim() { return delim; } - public String[] getColumns() { - return columns; + public StructType getSchema() { + return schema; } - - public static Builder newBuilder(Iterator iterator) { + public static Builder newBuilder(Iterator iterator) { return new Builder(iterator); } @@ -85,7 +86,7 @@ public static Builder newBuilder(Iterator iterator) { */ public static class Builder { - private final Iterator iterator; + private final Iterator iterator; private int batchSize; @@ -95,9 +96,9 @@ public static class Builder { private byte[] delim; - private String[] columns; + private StructType schema; - public Builder(Iterator iterator) { + public Builder(Iterator iterator) { this.iterator = iterator; } @@ -121,13 +122,13 @@ public Builder delim(String delim) { return this; } - public Builder columns(String[] columns) { - this.columns = columns; + public Builder schema(StructType schema) { + this.schema = schema; return this; } public RecordBatch build() { - return new RecordBatch(iterator, batchSize, format, sep, delim, columns); + return new RecordBatch(iterator, batchSize, format, sep, delim, schema); } } diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java index ab1251a1..830f5d91 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java @@ -7,12 +7,15 @@ import com.fasterxml.jackson.core.JsonProcessingException; import 
org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.Iterator; /** @@ -44,8 +47,20 @@ public class RecordBatchInputStream extends InputStream { */ private int readCount = 0; - public RecordBatchInputStream(RecordBatch recordBatch) { + /** + * streaming mode pass through data without process + */ + private final boolean passThrough; + + /** + * deserializer for converting InternalRow to Row + */ + private final ExpressionEncoder.Deserializer deserializer; + + public RecordBatchInputStream(RecordBatch recordBatch, ExpressionEncoder.Deserializer deserializer, boolean passThrough) { this.recordBatch = recordBatch; + this.deserializer = deserializer; + this.passThrough = passThrough; } @Override @@ -83,7 +98,7 @@ public int read(byte[] b, int off, int len) throws IOException { * @throws DorisException */ public boolean endOfBatch() throws DorisException { - Iterator iterator = recordBatch.getIterator(); + Iterator iterator = recordBatch.getIterator(); if (readCount >= recordBatch.getBatchSize() || !iterator.hasNext()) { return true; } @@ -97,7 +112,7 @@ public boolean endOfBatch() throws DorisException { * @param iterator row iterator * @throws DorisException */ - private void readNext(Iterator iterator) throws DorisException { + private void readNext(Iterator iterator) throws DorisException { if (!iterator.hasNext()) { throw new ShouldNeverHappenException(); } @@ -161,21 +176,28 @@ private int calculateNewCapacity(int capacity, int minCapacity) { /** * Convert Spark row data to byte array * - * @param row row data + * @param internalRow row data * @return byte array * @throws DorisException */ - private byte[] rowToByte(Row row) throws DorisException { + private byte[] rowToByte(InternalRow internalRow) throws DorisException { byte[] bytes; + Row row = deserializer.apply(internalRow.copy()); + + if (passThrough) { + bytes = row.getString(0).getBytes(StandardCharsets.UTF_8); + return bytes; + } + switch (recordBatch.getFormat().toLowerCase()) { case "csv": bytes = DataUtil.rowToCsvBytes(row, recordBatch.getSep()); break; case "json": try { - bytes = DataUtil.rowToJsonBytes(row, recordBatch.getColumns()); + bytes = DataUtil.rowToJsonBytes(row, recordBatch.getSchema().fieldNames()); } catch (JsonProcessingException e) { throw new DorisException("parse row to json bytes failed", e); } diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java index 65d5a15e..270266bd 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java @@ -67,11 +67,11 @@ public static byte[] rowToCsvBytes(Row row, String sep) { return builder.toString().getBytes(StandardCharsets.UTF_8); } - public static byte[] rowToJsonBytes(Row row, String[] dfColumns) + public static byte[] rowToJsonBytes(Row row, String[] columns) throws JsonProcessingException { Map rowMap = new HashMap<>(row.size()); - for (int i = 0; i < dfColumns.length; i++) { - rowMap.put(dfColumns[i], handleColumnValue(row.get(i))); + for (int i = 0; i < columns.length; i++) { + rowMap.put(columns[i], handleColumnValue(row.get(i))); } return 
MAPPER.writeValueAsBytes(rowMap); } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala index 10699bb2..9f9f99b3 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala @@ -22,14 +22,16 @@ import org.apache.doris.spark.listener.DorisTransactionListener import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad} import org.apache.doris.spark.sql.Utils import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.types.StructType import org.apache.spark.util.CollectionAccumulator import org.slf4j.{Logger, LoggerFactory} import java.io.IOException import java.time.Duration +import java.util import java.util.Objects import scala.collection.JavaConverters._ import scala.collection.mutable @@ -39,8 +41,6 @@ class DorisWriter(settings: SparkSettings) extends Serializable { private val logger: Logger = LoggerFactory.getLogger(classOf[DorisWriter]) - val batchSize: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE, - ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT) private val maxRetryTimes: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_MAX_RETRIES, ConfigurationOptions.SINK_MAX_RETRIES_DEFAULT) private val sinkTaskPartitionSize: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TASK_PARTITION_SIZE) @@ -55,43 +55,14 @@ class DorisWriter(settings: SparkSettings) extends Serializable { private val dorisStreamLoader: DorisStreamLoad = CachedDorisStreamLoadClient.getOrCreate(settings) def write(dataFrame: DataFrame): Unit = { - - val sc = dataFrame.sqlContext.sparkContext - val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc") - if (enable2PC) { - sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader)) - } - - var resultRdd = dataFrame.rdd - val dfColumns = dataFrame.columns - if (Objects.nonNull(sinkTaskPartitionSize)) { - resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize) - } - resultRdd - .foreachPartition(partition => - while (partition.hasNext) - loadBatch(partition, dfColumns, batchSize) - ) - - /** - * flush data to Doris and do retry when flush error - * - */ - def loadBatch(batch: Iterator[Row], dfColumns: Array[String], batchSize: Int): Unit = { - Utils.retry[Integer, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) { - dorisStreamLoader.load(batch.asJava, dfColumns, enable2PC, batchSize) - } match { - case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc) - case Failure(e) => - if (enable2PC) handleLoadFailure(preCommittedTxnAcc) - throw new IOException( - s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e) - } - } - + doWrite(dataFrame, dorisStreamLoader.load) } def writeStream(dataFrame: DataFrame): Unit = { + doWrite(dataFrame, dorisStreamLoader.loadStream) + } + + private def doWrite(dataFrame: DataFrame, loadFunc: 
(util.Iterator[InternalRow], StructType, Deserializer[Row]) => Int): Unit = { val sc = dataFrame.sqlContext.sparkContext val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc") @@ -101,47 +72,32 @@ class DorisWriter(settings: SparkSettings) extends Serializable { var resultRdd = dataFrame.queryExecution.toRdd val schema = dataFrame.schema - val dfColumns = dataFrame.columns + val deserializer = RowEncoder(schema).resolveAndBind().createDeserializer() if (Objects.nonNull(sinkTaskPartitionSize)) { resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize) } - resultRdd - .foreachPartition(partition => { - partition - .grouped(batchSize) - .foreach(batch => - flush(batch, dfColumns)) - }) - - /** - * flush data to Doris and do retry when flush error - * - */ - def flush(batch: Seq[InternalRow], dfColumns: Array[String]): Unit = { - Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) { - dorisStreamLoader.loadStream(convertToObjectList(batch, schema), dfColumns, enable2PC) - } match { - case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc) - case Failure(e) => - if (enable2PC) handleLoadFailure(preCommittedTxnAcc) - throw new IOException( - s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e) + resultRdd.foreachPartition(iterator => { + while (iterator.hasNext) { + // do load batch with retries + Utils.retry[Int, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) { + loadFunc(iterator.asJava, schema, deserializer) + } match { + case Success(txnId) => if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc) + case Failure(e) => + if (enable2PC) handleLoadFailure(preCommittedTxnAcc) + throw new IOException( + s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e) + } } - } - - def convertToObjectList(rows: Seq[InternalRow], schema: StructType): util.List[util.List[Object]] = { - rows.map(row => { - row.toSeq(schema).map(_.asInstanceOf[AnyRef]).toList.asJava - }).asJava - } + }) } - private def handleLoadSuccess(txnIds: mutable.Buffer[Integer], acc: CollectionAccumulator[Int]): Unit = { - txnIds.foreach(txnId => acc.add(txnId)) + private def handleLoadSuccess(txnId: Int, acc: CollectionAccumulator[Int]): Unit = { + acc.add(txnId) } - def handleLoadFailure(acc: CollectionAccumulator[Int]): Unit = { + private def handleLoadFailure(acc: CollectionAccumulator[Int]): Unit = { // if task run failed, acc value will not be returned to driver, // should abort all pre committed transactions inside the task logger.info("load task failed, start aborting previously pre-committed transactions") From c53bd9ae0107902ab0f6bb6693c9e780adbcc61c Mon Sep 17 00:00:00 2001 From: gnehil Date: Thu, 7 Sep 2023 19:22:39 +0800 Subject: [PATCH 08/10] remove unused import --- .../main/java/org/apache/doris/spark/load/DorisStreamLoad.java | 1 - 1 file changed, 1 deletion(-) diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java index d9b1d516..23aab298 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java +++ 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java @@ -35,7 +35,6 @@ import org.apache.http.HttpHeaders; import org.apache.http.HttpResponse; import org.apache.http.HttpStatus; -import org.apache.http.client.config.RequestConfig; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpPut; import org.apache.http.client.methods.HttpRequestBase; From a4e99e57ca7528361888c6940ecde2c81abcd39e Mon Sep 17 00:00:00 2001 From: gnehil Date: Fri, 8 Sep 2023 18:18:40 +0800 Subject: [PATCH 09/10] convert internal row value manually --- .../doris/spark/load/DorisStreamLoad.java | 10 ++-- .../spark/load/RecordBatchInputStream.java | 20 ++----- .../org/apache/doris/spark/util/DataUtil.java | 45 +++++--------- .../apache/doris/spark/sql/SchemaUtils.scala | 58 ++++++++++++++++++- .../doris/spark/writer/DorisWriter.scala | 11 ++-- .../apache/doris/spark/util/DataUtilTest.java | 32 ---------- .../doris/spark/sql/SchemaUtilsTest.scala | 37 ++++++++++++ 7 files changed, 121 insertions(+), 92 deletions(-) delete mode 100644 spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java create mode 100644 spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java index 23aab298..9ecfa405 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java @@ -43,9 +43,7 @@ import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClientBuilder; import org.apache.http.util.EntityUtils; -import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer; import org.apache.spark.sql.types.StructType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -177,7 +175,7 @@ public String toString() { } } - public int load(Iterator rows, StructType schema, Deserializer deserializer) + public int load(Iterator rows, StructType schema) throws StreamLoadException, JsonProcessingException { String label = generateLoadLabel(); @@ -191,7 +189,7 @@ public int load(Iterator rows, StructType schema, Deserializer .format(fileType) .sep(FIELD_DELIMITER) .delim(LINE_DELIMITER) - .schema(schema).build(), deserializer, streamingPassthrough); + .schema(schema).build(), streamingPassthrough); httpPut.setEntity(new InputStreamEntity(recodeBatchInputStream)); HttpResponse httpResponse = httpClient.execute(httpPut); loadResponse = new LoadResponse(httpResponse); @@ -218,12 +216,12 @@ public int load(Iterator rows, StructType schema, Deserializer } - public Integer loadStream(Iterator rows, StructType schema, Deserializer deserializer) + public Integer loadStream(Iterator rows, StructType schema) throws StreamLoadException, JsonProcessingException { if (this.streamingPassthrough) { handleStreamPassThrough(); } - return load(rows, schema, deserializer); + return load(rows, schema); } public void commit(int txnId) throws StreamLoadException { diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java index 830f5d91..6d7686fd 100644 --- 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java @@ -6,9 +6,7 @@ import org.apache.doris.spark.util.DataUtil; import com.fasterxml.jackson.core.JsonProcessingException; -import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,14 +50,8 @@ public class RecordBatchInputStream extends InputStream { */ private final boolean passThrough; - /** - * deserializer for converting InternalRow to Row - */ - private final ExpressionEncoder.Deserializer deserializer; - - public RecordBatchInputStream(RecordBatch recordBatch, ExpressionEncoder.Deserializer deserializer, boolean passThrough) { + public RecordBatchInputStream(RecordBatch recordBatch, boolean passThrough) { this.recordBatch = recordBatch; - this.deserializer = deserializer; this.passThrough = passThrough; } @@ -176,16 +168,14 @@ private int calculateNewCapacity(int capacity, int minCapacity) { /** * Convert Spark row data to byte array * - * @param internalRow row data + * @param row row data * @return byte array * @throws DorisException */ - private byte[] rowToByte(InternalRow internalRow) throws DorisException { + private byte[] rowToByte(InternalRow row) throws DorisException { byte[] bytes; - Row row = deserializer.apply(internalRow.copy()); - if (passThrough) { bytes = row.getString(0).getBytes(StandardCharsets.UTF_8); return bytes; @@ -193,11 +183,11 @@ private byte[] rowToByte(InternalRow internalRow) throws DorisException { switch (recordBatch.getFormat().toLowerCase()) { case "csv": - bytes = DataUtil.rowToCsvBytes(row, recordBatch.getSep()); + bytes = DataUtil.rowToCsvBytes(row, recordBatch.getSchema(), recordBatch.getSep()); break; case "json": try { - bytes = DataUtil.rowToJsonBytes(row, recordBatch.getSchema().fieldNames()); + bytes = DataUtil.rowToJsonBytes(row, recordBatch.getSchema()); } catch (JsonProcessingException e) { throw new DorisException("parse row to json bytes failed", e); } diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java index 270266bd..aea6ddee 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java @@ -17,14 +17,15 @@ package org.apache.doris.spark.util; +import org.apache.doris.spark.sql.SchemaUtils; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; -import org.apache.spark.sql.Row; -import scala.collection.mutable.WrappedArray; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import java.nio.charset.StandardCharsets; -import java.sql.Date; -import java.sql.Timestamp; import java.util.HashMap; import java.util.Map; @@ -34,44 +35,28 @@ public class DataUtil { public static final String NULL_VALUE = "\\N"; - public static Object handleColumnValue(Object value) { - - if (value == null) { - return NULL_VALUE; - } - - if (value instanceof Date || value instanceof Timestamp) { - return value.toString(); - } - - if (value instanceof WrappedArray) { - return String.format("[%s]", ((WrappedArray) value).mkString(",")); - } - - 
return value; - - } - - public static byte[] rowToCsvBytes(Row row, String sep) { + public static byte[] rowToCsvBytes(InternalRow row, StructType schema, String sep) { StringBuilder builder = new StringBuilder(); - int n = row.size(); + StructField[] fields = schema.fields(); + int n = row.numFields(); if (n > 0) { - builder.append(handleColumnValue(row.get(0))); + builder.append(SchemaUtils.rowColumnValue(row, 0, fields[0].dataType())); int i = 1; while (i < n) { builder.append(sep); - builder.append(handleColumnValue(row.get(i))); + builder.append(SchemaUtils.rowColumnValue(row, i, fields[i].dataType())); i++; } } return builder.toString().getBytes(StandardCharsets.UTF_8); } - public static byte[] rowToJsonBytes(Row row, String[] columns) + public static byte[] rowToJsonBytes(InternalRow row, StructType schema) throws JsonProcessingException { - Map rowMap = new HashMap<>(row.size()); - for (int i = 0; i < columns.length; i++) { - rowMap.put(columns[i], handleColumnValue(row.get(i))); + StructField[] fields = schema.fields(); + Map rowMap = new HashMap<>(row.numFields()); + for (int i = 0; i < fields.length; i++) { + rowMap.put(fields[i].name(), SchemaUtils.rowColumnValue(row, i, fields[i].dataType())); } return MAPPER.writeValueAsBytes(rowMap); } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala index c8aa0349..f5a6a159 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala @@ -18,16 +18,23 @@ package org.apache.doris.spark.sql import org.apache.doris.sdk.thrift.TScanColumnDesc - -import scala.collection.JavaConversions._ +import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE, DORIS_READ_FIELD} import org.apache.doris.spark.cfg.Settings import org.apache.doris.spark.exception.DorisException import org.apache.doris.spark.rest.RestService import org.apache.doris.spark.rest.models.{Field, Schema} -import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE, DORIS_READ_FIELD} +import org.apache.doris.spark.util.DataUtil +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.slf4j.LoggerFactory +import java.sql.Timestamp +import java.time.{LocalDateTime, ZoneOffset} +import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable + private[spark] object SchemaUtils { private val logger = LoggerFactory.getLogger(SchemaUtils.getClass.getSimpleName.stripSuffix("$")) @@ -137,4 +144,49 @@ private[spark] object SchemaUtils { tscanColumnDescs.foreach(desc => schema.put(new Field(desc.getName, desc.getType.name, "", 0, 0, ""))) schema } + + def rowColumnValue(row: SpecializedGetters, ordinal: Int, dataType: DataType): Any = { + + dataType match { + case NullType => DataUtil.NULL_VALUE + case BooleanType => row.getBoolean(ordinal) + case ByteType => row.getByte(ordinal) + case ShortType => row.getShort(ordinal) + case IntegerType => row.getInt(ordinal) + case LongType => row.getLong(ordinal) + case FloatType => row.getFloat(ordinal) + case DoubleType => row.getDouble(ordinal) + case StringType => row.getUTF8String(ordinal).toString + case TimestampType => + LocalDateTime.ofEpochSecond(row.getLong(ordinal) / 100000, (row.getLong(ordinal) % 
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
index c8aa0349..f5a6a159 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
@@ -18,16 +18,23 @@
 package org.apache.doris.spark.sql
 
 import org.apache.doris.sdk.thrift.TScanColumnDesc
-
-import scala.collection.JavaConversions._
+import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE, DORIS_READ_FIELD}
 import org.apache.doris.spark.cfg.Settings
 import org.apache.doris.spark.exception.DorisException
 import org.apache.doris.spark.rest.RestService
 import org.apache.doris.spark.rest.models.{Field, Schema}
-import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE, DORIS_READ_FIELD}
+import org.apache.doris.spark.util.DataUtil
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
 import org.slf4j.LoggerFactory
 
+import java.sql.Timestamp
+import java.time.{LocalDateTime, ZoneOffset}
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
 private[spark] object SchemaUtils {
   private val logger = LoggerFactory.getLogger(SchemaUtils.getClass.getSimpleName.stripSuffix("$"))
 
@@ -137,4 +144,49 @@ private[spark] object SchemaUtils {
     tscanColumnDescs.foreach(desc => schema.put(new Field(desc.getName, desc.getType.name, "", 0, 0, "")))
     schema
   }
+
+  def rowColumnValue(row: SpecializedGetters, ordinal: Int, dataType: DataType): Any = {
+
+    dataType match {
+      case NullType => DataUtil.NULL_VALUE
+      case BooleanType => row.getBoolean(ordinal)
+      case ByteType => row.getByte(ordinal)
+      case ShortType => row.getShort(ordinal)
+      case IntegerType => row.getInt(ordinal)
+      case LongType => row.getLong(ordinal)
+      case FloatType => row.getFloat(ordinal)
+      case DoubleType => row.getDouble(ordinal)
+      case StringType => row.getUTF8String(ordinal).toString
+      case TimestampType =>
+        // internal value is microseconds since the epoch; convert to milliseconds for java.sql.Timestamp
+        new Timestamp(row.getLong(ordinal) / 1000).toString
+      case DateType => DateTimeUtils.toJavaDate(row.getInt(ordinal)).toString
+      case BinaryType => row.getBinary(ordinal)
+      case dt: DecimalType => row.getDecimal(ordinal, dt.precision, dt.scale)
+      case at: ArrayType =>
+        val arrayData = row.getArray(ordinal)
+        var i = 0
+        val buffer = mutable.Buffer[Any]()
+        while (i < arrayData.numElements()) {
+          if (arrayData.isNullAt(i)) buffer += null else buffer += rowColumnValue(arrayData, i, at.elementType)
+          i += 1
+        }
+        s"[${buffer.mkString(",")}]"
+      case mt: MapType =>
+        val mapData = row.getMap(ordinal)
+        val keys = mapData.keyArray()
+        val values = mapData.valueArray()
+        var i = 0
+        val map = mutable.Map[Any, Any]()
+        while (i < keys.numElements()) {
+          map += rowColumnValue(keys, i, mt.keyType) -> rowColumnValue(values, i, mt.valueType)
+          i += 1
+        }
+        map.toMap.asJava
+      case st: StructType => row.getStruct(ordinal, st.length)
+      case _ => throw new DorisException(s"Unsupported spark type: ${dataType.typeName}")
+    }
+
+  }
+
 }
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
index 9f9f99b3..b278a385 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
@@ -21,10 +21,8 @@ import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
 import org.apache.doris.spark.listener.DorisTransactionListener
 import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
 import org.apache.doris.spark.sql.Utils
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.CollectionAccumulator
 import org.slf4j.{Logger, LoggerFactory}
@@ -62,7 +60,9 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
     doWrite(dataFrame, dorisStreamLoader.loadStream)
   }
 
-  private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType, Deserializer[Row]) => Int): Unit = {
+  private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType) => Int): Unit = {
+
+
     val sc = dataFrame.sqlContext.sparkContext
     val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc")
@@ -72,7 +72,6 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
     var resultRdd = dataFrame.queryExecution.toRdd
     val schema = dataFrame.schema
-    val deserializer = RowEncoder(schema).resolveAndBind().createDeserializer()
     if (Objects.nonNull(sinkTaskPartitionSize)) {
       resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
     }
@@ -80,7 +79,7 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
       while (iterator.hasNext) {
         // do load batch with retries
         Utils.retry[Int, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
-          loadFunc(iterator.asJava, schema, deserializer)
+          loadFunc(iterator.asJava, schema)
         } match {
           case Success(txnId) => if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc)
           case Failure(e) =>
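Each batch handed to loadFunc above is wrapped in Utils.retry with a fixed interval between attempts. Below is a simplified standalone sketch of that retry shape, assuming nothing beyond the standard library; it is illustrative only and not the connector's Utils.retry implementation.

import java.time.Duration

import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}

object RetrySketch {

  // Run `block`, retrying up to `times` more attempts and sleeping `interval` between attempts.
  @tailrec
  def retry[T](times: Int, interval: Duration)(block: => T): Try[T] = {
    Try(block) match {
      case s @ Success(_) => s
      case Failure(_) if times > 0 =>
        Thread.sleep(interval.toMillis)
        retry(times - 1, interval)(block)
      case f @ Failure(_) => f
    }
  }

  def main(args: Array[String]): Unit = {
    var attempts = 0
    val result = retry(3, Duration.ofMillis(100)) {
      attempts += 1
      if (attempts < 3) throw new RuntimeException("stream load failed") else attempts
    }
    println(result) // Success(3): the first two attempts fail, the third succeeds
  }
}

diff --git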
a/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java b/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java deleted file mode 100644 index 0f6fb36b..00000000 --- a/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java +++ /dev/null @@ -1,32 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.spark.util; - -import junit.framework.TestCase; -import org.junit.Assert; -import scala.collection.mutable.WrappedArray; - -import java.sql.Timestamp; - -public class DataUtilTest extends TestCase { - - public void testHandleColumnValue() { - Assert.assertEquals("2023-08-14 18:00:00.0", DataUtil.handleColumnValue(Timestamp.valueOf("2023-08-14 18:00:00"))); - Assert.assertEquals("[1,2,3]", DataUtil.handleColumnValue(WrappedArray.make(new Integer[]{1,2,3}))); - } -} \ No newline at end of file diff --git a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala new file mode 100644 index 00000000..fb729cef --- /dev/null +++ b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala @@ -0,0 +1,37 @@ +package org.apache.doris.spark.sql + +import org.apache.spark.sql.SparkSession +import org.junit.{Assert, Ignore, Test} + +import java.sql.{Date, Timestamp} +import scala.collection.JavaConverters._ + +@Ignore +class SchemaUtilsTest { + + @Test + def rowColumnValueTest(): Unit = { + + val spark = SparkSession.builder().master("local").getOrCreate() + + val df = spark.createDataFrame(Seq( + (1, Date.valueOf("2023-09-08"), Timestamp.valueOf("2023-09-08 17:00:00"), Array(1, 2, 3), Map[String, String]("a" -> "1")) + )).toDF("c1", "c2", "c3", "c4", "c5") + + val schema = df.schema + + df.queryExecution.toRdd.foreach(row => { + + val fields = schema.fields + Assert.assertEquals(1, SchemaUtils.rowColumnValue(row, 0, fields(0).dataType)) + Assert.assertEquals("2023-09-08", SchemaUtils.rowColumnValue(row, 1, fields(1).dataType)) + Assert.assertEquals("2023-09-08 17:00:00.0", SchemaUtils.rowColumnValue(row, 2, fields(2).dataType)) + Assert.assertEquals("[1,2,3]", SchemaUtils.rowColumnValue(row, 3, fields(3).dataType)) + println(SchemaUtils.rowColumnValue(row, 4, fields(4).dataType)) + Assert.assertEquals(Map("a" -> "1").asJava, SchemaUtils.rowColumnValue(row, 4, fields(4).dataType)) + + }) + + } + +} From 1976e5d5981537880b70773bc101c3ab94f237ba Mon Sep 17 00:00:00 2001 From: gnehil Date: Fri, 8 Sep 2023 19:30:41 +0800 Subject: [PATCH 10/10] add license header --- .../apache/doris/spark/load/RecordBatch.java | 17 +++++++++++++++++ .../spark/load/RecordBatchInputStream.java | 17 +++++++++++++++++ 
.../doris/spark/sql/SchemaUtilsTest.scala | 17 +++++++++++++++++ 3 files changed, 51 insertions(+) diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java index caeb4c9a..779c057d 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + package org.apache.doris.spark.load; import org.apache.spark.sql.catalyst.InternalRow; diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java index 6d7686fd..9444c1da 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + package org.apache.doris.spark.load; import org.apache.doris.spark.exception.DorisException; diff --git a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala index fb729cef..e3868cbc 100644 --- a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala +++ b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + package org.apache.doris.spark.sql import org.apache.spark.sql.SparkSession