diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala index fc01d6b7..38abffe1 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala @@ -69,7 +69,7 @@ private[sql] class DorisSourceProvider extends DataSourceRegister } // accumulator for transaction handling - val acc = sqlContext.sparkContext.collectionAccumulator[CommitMessage]("BatchTxnAcc") + val acc = sqlContext.sparkContext.collectionAccumulator[(String, CommitMessage)]("BatchTxnAcc") // init stream loader val writer = new DorisWriter(sparkSettings, acc, false) writer.write(data) diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala index eb0ac126..78bbe848 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala @@ -31,7 +31,7 @@ private[sql] class DorisStreamLoadSink(sqlContext: SQLContext, settings: SparkSe @volatile private var latestBatchId = -1L // accumulator for transaction handling - private val acc = sqlContext.sparkContext.collectionAccumulator[CommitMessage]("StreamTxnAcc") + private val acc = sqlContext.sparkContext.collectionAccumulator[(String, CommitMessage)]("StreamTxnAcc") private val writer = new DorisWriter(settings, acc, true) // add listener for structured streaming @@ -41,7 +41,8 @@ private[sql] class DorisStreamLoadSink(sqlContext: SQLContext, settings: SparkSe if (batchId <= latestBatchId) { logger.info(s"Skipping already committed batch $batchId") } else { - writer.write(data) + val runId = sqlContext.streams.active.head.runId.toString + writer.write(data, Some(runId)) latestBatchId = batchId } } diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTransactionListener.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTransactionListener.scala index 1d331906..1503a343 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTransactionListener.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTransactionListener.scala @@ -26,19 +26,28 @@ import org.apache.spark.util.CollectionAccumulator import scala.collection.JavaConverters._ import scala.collection.mutable -class DorisTransactionListener(txnAcc: CollectionAccumulator[CommitMessage], txnHandler: TransactionHandler) +class DorisTransactionListener(txnAcc: CollectionAccumulator[(String, CommitMessage)], txnHandler: TransactionHandler) extends SparkListener with Logging { + private val jobToStages = mutable.HashMap[Int, List[Int]]() + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobToStages += jobStart.jobId -> jobStart.stageIds.toList + } + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - val messages: mutable.Buffer[CommitMessage] = txnAcc.value.asScala + val stageIds = jobToStages.get(jobEnd.jobId) + val messages: mutable.Buffer[CommitMessage] = txnAcc.value.asScala.filter(item => { + stageIds.nonEmpty && stageIds.get.contains(item._1.toInt) + }).map(_._2) jobEnd.jobResult match { // if job succeed, commit all transactions case JobSucceeded => if (messages.isEmpty) { - log.debug("job run succeed, but there is no pre-committed txn ids") + log.debug(s"job ${jobEnd.jobId} run succeed, but there is no pre-committed txn ids") return } - log.info("job run succeed, start committing transactions") + log.info(s"job ${jobEnd.jobId} run succeed, start committing transactions") try txnHandler.commitTransactions(messages.toList) catch { case e: Exception => throw e @@ -48,10 +57,10 @@ class DorisTransactionListener(txnAcc: CollectionAccumulator[CommitMessage], txn // if job failed, abort all pre committed transactions case _ => if (messages.isEmpty) { - log.debug("job run failed, but there is no pre-committed txn ids") + log.debug(s"job ${jobEnd.jobId} run failed, but there is no pre-committed txn ids") return } - log.info("job run failed, start aborting transactions") + log.info(s"job ${jobEnd.jobId} run failed, start aborting transactions") try txnHandler.abortTransactions(messages.toList) catch { case e: Exception => throw e diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTxnStreamingQueryListener.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTxnStreamingQueryListener.scala index 0ddd4efc..f81d13a4 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTxnStreamingQueryListener.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTxnStreamingQueryListener.scala @@ -24,17 +24,23 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.streaming.StreamingQueryListener import org.apache.spark.util.CollectionAccumulator +import java.util.UUID import scala.collection.JavaConverters._ import scala.collection.mutable -class DorisTxnStreamingQueryListener(txnAcc: CollectionAccumulator[CommitMessage], txnHandler: TransactionHandler) +class DorisTxnStreamingQueryListener(txnAcc: CollectionAccumulator[(String, CommitMessage)], txnHandler: TransactionHandler) extends StreamingQueryListener with Logging { - override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {} + override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = { + + } override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = { + // do commit transaction when each batch ends val messages: mutable.Buffer[CommitMessage] = txnAcc.value.asScala + .filter(item => UUID.fromString(item._1) equals event.progress.runId) + .map(_._2) if (messages.isEmpty) { log.warn("job run succeed, but there is no pre-committed txn") return @@ -49,7 +55,7 @@ class DorisTxnStreamingQueryListener(txnAcc: CollectionAccumulator[CommitMessage override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = { - val messages: mutable.Buffer[CommitMessage] = txnAcc.value.asScala + val messages: mutable.Buffer[CommitMessage] = txnAcc.value.asScala.map(_._2) // if job failed, abort all pre committed transactions if (event.exception.nonEmpty) { if (messages.isEmpty) { diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala index 9886b526..10988732 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala @@ -22,6 +22,7 @@ import org.apache.doris.spark.load.{CommitMessage, CopyIntoLoader, Loader, Strea import org.apache.doris.spark.sql.Utils import org.apache.doris.spark.txn.TransactionHandler import org.apache.doris.spark.txn.listener.DorisTransactionListener +import org.apache.spark.TaskContext import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.StructType @@ -35,13 +36,13 @@ import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Success} class DorisWriter(settings: SparkSettings, - txnAcc: CollectionAccumulator[CommitMessage], + txnAcc: CollectionAccumulator[(String, CommitMessage)], isStreaming: Boolean) extends Serializable { private val logger: Logger = LoggerFactory.getLogger(classOf[DorisWriter]) private val sinkTaskPartitionSize: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TASK_PARTITION_SIZE) - private val loadMode: String = settings.getProperty(ConfigurationOptions.LOAD_MODE,ConfigurationOptions.DEFAULT_LOAD_MODE) + private val loadMode: String = settings.getProperty(ConfigurationOptions.LOAD_MODE, ConfigurationOptions.DEFAULT_LOAD_MODE) private val sinkTaskUseRepartition: Boolean = settings.getProperty(ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION, ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION_DEFAULT.toString).toBoolean @@ -74,11 +75,11 @@ class DorisWriter(settings: SparkSettings, * * @param dataFrame source dataframe */ - def write(dataFrame: DataFrame): Unit = { - doWrite(dataFrame, loader.load) + def write(dataFrame: DataFrame, id: Option[String] = None): Unit = { + doWrite(dataFrame, loader.load, id) } - private def doWrite(dataFrame: DataFrame, loadFunc: (Iterator[InternalRow], StructType) => Option[CommitMessage]): Unit = { + private def doWrite(dataFrame: DataFrame, loadFunc: (Iterator[InternalRow], StructType) => Option[CommitMessage], id: Option[String]): Unit = { // do not add spark listener when job is streaming mode if (enable2PC && !isStreaming) { dataFrame.sparkSession.sparkContext.addSparkListener(new DorisTransactionListener(txnAcc, txnHandler)) @@ -92,12 +93,13 @@ class DorisWriter(settings: SparkSettings, val schema = resultDataFrame.schema resultRdd.foreachPartition(iterator => { + val accId = if (id.isEmpty) TaskContext.get().stageId().toString else id.get while (iterator.hasNext) { val batchIterator = new BatchIterator(iterator, batchSize, maxRetryTimes > 0) val retry = Utils.retry[Option[CommitMessage], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) _ retry(loadFunc(batchIterator, schema))(batchIterator.reset()) match { case Success(msg) => - if (enable2PC) handleLoadSuccess(msg, txnAcc) + if (enable2PC) handleLoadSuccess(accId, msg, txnAcc) batchIterator.close() case Failure(e) => if (enable2PC) handleLoadFailure(txnAcc) @@ -110,11 +112,11 @@ class DorisWriter(settings: SparkSettings, } - private def handleLoadSuccess(msg: Option[CommitMessage], acc: CollectionAccumulator[CommitMessage]): Unit = { - acc.add(msg.get) + private def handleLoadSuccess(id: String, msg: Option[CommitMessage], acc: CollectionAccumulator[(String, CommitMessage)]): Unit = { + acc.add((id, msg.get)) } - private def handleLoadFailure(acc: CollectionAccumulator[CommitMessage]): Unit = { + private def handleLoadFailure(acc: CollectionAccumulator[(String, CommitMessage)]): Unit = { // if task run failed, acc value will not be returned to driver, // should abort all pre committed transactions inside the task logger.info("load task failed, start aborting previously pre-committed transactions") @@ -123,7 +125,7 @@ class DorisWriter(settings: SparkSettings, return } - try txnHandler.abortTransactions(acc.value.asScala.toList) + try txnHandler.abortTransactions(acc.value.asScala.map(_._2).toList) catch { case e: Exception => throw e }