[bug] fix multi job transaction commit trigger conflict #217

Open · wants to merge 1 commit into base: master
@@ -69,7 +69,7 @@ private[sql] class DorisSourceProvider extends DataSourceRegister
}

// accumulator for transaction handling
- val acc = sqlContext.sparkContext.collectionAccumulator[CommitMessage]("BatchTxnAcc")
+ val acc = sqlContext.sparkContext.collectionAccumulator[(String, CommitMessage)]("BatchTxnAcc")
// init stream loader
val writer = new DorisWriter(sparkSettings, acc, false)
writer.write(data)
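
Note on the accumulator change above: the commit-message accumulator now stores (String, CommitMessage) pairs instead of bare CommitMessages, so every pre-committed transaction is tagged with an owner id (a stage id for batch jobs, a streaming runId for streaming queries). A minimal sketch of the idea, outside the connector — CommitMessage is stood in by a placeholder case class and the ids are made up:

```scala
// Sketch only: a keyed CollectionAccumulator lets the driver tell apart
// transactions produced by different concurrently running write jobs.
import org.apache.spark.sql.SparkSession
import scala.collection.JavaConverters._

object KeyedTxnAccSketch {
  // Placeholder for the connector's CommitMessage (real shape not shown here).
  case class CommitMessage(txnId: Long)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()

    // Same element type as this PR: (ownerId, CommitMessage).
    val acc = spark.sparkContext.collectionAccumulator[(String, CommitMessage)]("BatchTxnAcc")

    // Tasks of two different jobs add messages keyed by their owner id.
    acc.add(("12", CommitMessage(1001L)))
    acc.add(("27", CommitMessage(2001L)))

    // A listener responsible for stage 12 only touches its own messages.
    val mine = acc.value.asScala.filter(_._1 == "12").map(_._2)
    println(mine) // Buffer(CommitMessage(1001))

    spark.stop()
  }
}
```
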
@@ -31,7 +31,7 @@ private[sql] class DorisStreamLoadSink(sqlContext: SQLContext, settings: SparkSe
@volatile private var latestBatchId = -1L

// accumulator for transaction handling
- private val acc = sqlContext.sparkContext.collectionAccumulator[CommitMessage]("StreamTxnAcc")
+ private val acc = sqlContext.sparkContext.collectionAccumulator[(String, CommitMessage)]("StreamTxnAcc")
private val writer = new DorisWriter(settings, acc, true)

// add listener for structured streaming
@@ -41,7 +41,8 @@ private[sql] class DorisStreamLoadSink(sqlContext: SQLContext, settings: SparkSe
if (batchId <= latestBatchId) {
logger.info(s"Skipping already committed batch $batchId")
} else {
- writer.write(data)
+ val runId = sqlContext.streams.active.head.runId.toString
+ writer.write(data, Some(runId))
latestBatchId = batchId
}
}
@@ -26,19 +26,28 @@ import org.apache.spark.util.CollectionAccumulator
import scala.collection.JavaConverters._
import scala.collection.mutable

- class DorisTransactionListener(txnAcc: CollectionAccumulator[CommitMessage], txnHandler: TransactionHandler)
+ class DorisTransactionListener(txnAcc: CollectionAccumulator[(String, CommitMessage)], txnHandler: TransactionHandler)
extends SparkListener with Logging {

+ private val jobToStages = mutable.HashMap[Int, List[Int]]()
+
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+ jobToStages += jobStart.jobId -> jobStart.stageIds.toList
+ }

override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
- val messages: mutable.Buffer[CommitMessage] = txnAcc.value.asScala
+ val stageIds = jobToStages.get(jobEnd.jobId)
+ val messages: mutable.Buffer[CommitMessage] = txnAcc.value.asScala.filter(item => {
+ stageIds.nonEmpty && stageIds.get.contains(item._1.toInt)
+ }).map(_._2)
jobEnd.jobResult match {
// if job succeed, commit all transactions
case JobSucceeded =>
if (messages.isEmpty) {
- log.debug("job run succeed, but there is no pre-committed txn ids")
+ log.debug(s"job ${jobEnd.jobId} run succeed, but there is no pre-committed txn ids")
return
}
- log.info("job run succeed, start committing transactions")
+ log.info(s"job ${jobEnd.jobId} run succeed, start committing transactions")
try txnHandler.commitTransactions(messages.toList)
catch {
case e: Exception => throw e
@@ -48,10 +57,10 @@ class DorisTransactionListener(txnAcc: CollectionAccumulator[CommitMessage], txn
// if job failed, abort all pre committed transactions
case _ =>
if (messages.isEmpty) {
- log.debug("job run failed, but there is no pre-committed txn ids")
+ log.debug(s"job ${jobEnd.jobId} run failed, but there is no pre-committed txn ids")
return
}
- log.info("job run failed, start aborting transactions")
+ log.info(s"job ${jobEnd.jobId} run failed, start aborting transactions")
try txnHandler.abortTransactions(messages.toList)
catch {
case e: Exception => throw e
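
The listener above only sees job-level events, while executors key their accumulator entries by stage id, so onJobStart records the finished job's stage ids and onJobEnd keeps only the entries whose stage belongs to that job. A small standalone sketch of that filter, with made-up ids and plain strings in place of CommitMessage:

```scala
// Sketch of the job-to-stage scoping: entries tagged with a stage id from a
// different job are left for that job's own listener to commit or abort.
import scala.collection.mutable

object JobScopedFilterSketch extends App {
  // Filled in onJobStart from SparkListenerJobStart.stageIds.
  val jobToStages = mutable.HashMap[Int, List[Int]]()
  jobToStages += 3 -> List(12, 13)

  // Simulated accumulator content: (stageId, txn message) pairs from two jobs.
  val accValue = Seq(("12", "txn-1001"), ("27", "txn-2001"))

  // onJobEnd(jobId = 3): keep only messages from this job's stages.
  val stageIds = jobToStages.get(3)
  val messages = accValue
    .filter(item => stageIds.nonEmpty && stageIds.get.contains(item._1.toInt))
    .map(_._2)

  println(messages) // List(txn-1001): stage 27 belongs to another job
}
```
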
@@ -24,17 +24,23 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.util.CollectionAccumulator

+ import java.util.UUID
import scala.collection.JavaConverters._
import scala.collection.mutable

- class DorisTxnStreamingQueryListener(txnAcc: CollectionAccumulator[CommitMessage], txnHandler: TransactionHandler)
+ class DorisTxnStreamingQueryListener(txnAcc: CollectionAccumulator[(String, CommitMessage)], txnHandler: TransactionHandler)
extends StreamingQueryListener with Logging {

- override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {}
+ override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
+
+ }

override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {

// do commit transaction when each batch ends
val messages: mutable.Buffer[CommitMessage] = txnAcc.value.asScala
+ .filter(item => UUID.fromString(item._1) equals event.progress.runId)
+ .map(_._2)
if (messages.isEmpty) {
log.warn("job run succeed, but there is no pre-committed txn")
return
@@ -49,7 +55,7 @@ class DorisTxnStreamingQueryListener(txnAcc: CollectionAccumulator[CommitMessage


override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
- val messages: mutable.Buffer[CommitMessage] = txnAcc.value.asScala
+ val messages: mutable.Buffer[CommitMessage] = txnAcc.value.asScala.map(_._2)
// if job failed, abort all pre committed transactions
if (event.exception.nonEmpty) {
if (messages.isEmpty) {
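
For streaming queries the owner id is the query's runId (a UUID rendered as a string), so onQueryProgress commits only the messages tagged with the reporting query's run, while onQueryTerminated still aborts whatever is left in the accumulator. A hedged standalone sketch of the runId filter, with placeholder strings instead of CommitMessage:

```scala
// Sketch of the runId-based scoping used in onQueryProgress above.
import java.util.UUID

object RunIdFilterSketch extends App {
  val thisRun  = UUID.randomUUID() // would come from event.progress.runId
  val otherRun = UUID.randomUUID()

  // Simulated accumulator content: (runId string, txn message) pairs.
  val accValue = Seq((thisRun.toString, "txn-1001"), (otherRun.toString, "txn-2001"))

  // Commit only the transactions written by this query's current run.
  val messages = accValue
    .filter(item => UUID.fromString(item._1) == thisRun)
    .map(_._2)

  println(messages) // List(txn-1001)
}
```
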
@@ -22,6 +22,7 @@ import org.apache.doris.spark.load.{CommitMessage, CopyIntoLoader, Loader, Strea
import org.apache.doris.spark.sql.Utils
import org.apache.doris.spark.txn.TransactionHandler
import org.apache.doris.spark.txn.listener.DorisTransactionListener
+ import org.apache.spark.TaskContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
@@ -35,13 +36,13 @@ import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Success}

class DorisWriter(settings: SparkSettings,
- txnAcc: CollectionAccumulator[CommitMessage],
+ txnAcc: CollectionAccumulator[(String, CommitMessage)],
isStreaming: Boolean) extends Serializable {

private val logger: Logger = LoggerFactory.getLogger(classOf[DorisWriter])

private val sinkTaskPartitionSize: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TASK_PARTITION_SIZE)
- private val loadMode: String = settings.getProperty(ConfigurationOptions.LOAD_MODE,ConfigurationOptions.DEFAULT_LOAD_MODE)
+ private val loadMode: String = settings.getProperty(ConfigurationOptions.LOAD_MODE, ConfigurationOptions.DEFAULT_LOAD_MODE)
private val sinkTaskUseRepartition: Boolean = settings.getProperty(ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION,
ConfigurationOptions.DORIS_SINK_TASK_USE_REPARTITION_DEFAULT.toString).toBoolean

@@ -74,11 +75,11 @@ class DorisWriter(settings: SparkSettings,
*
* @param dataFrame source dataframe
*/
- def write(dataFrame: DataFrame): Unit = {
- doWrite(dataFrame, loader.load)
+ def write(dataFrame: DataFrame, id: Option[String] = None): Unit = {
+ doWrite(dataFrame, loader.load, id)
}

- private def doWrite(dataFrame: DataFrame, loadFunc: (Iterator[InternalRow], StructType) => Option[CommitMessage]): Unit = {
+ private def doWrite(dataFrame: DataFrame, loadFunc: (Iterator[InternalRow], StructType) => Option[CommitMessage], id: Option[String]): Unit = {
// do not add spark listener when job is streaming mode
if (enable2PC && !isStreaming) {
dataFrame.sparkSession.sparkContext.addSparkListener(new DorisTransactionListener(txnAcc, txnHandler))
@@ -92,12 +93,13 @@ class DorisWriter(settings: SparkSettings,
val schema = resultDataFrame.schema

resultRdd.foreachPartition(iterator => {
+ val accId = if (id.isEmpty) TaskContext.get().stageId().toString else id.get
while (iterator.hasNext) {
val batchIterator = new BatchIterator(iterator, batchSize, maxRetryTimes > 0)
val retry = Utils.retry[Option[CommitMessage], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) _
retry(loadFunc(batchIterator, schema))(batchIterator.reset()) match {
case Success(msg) =>
- if (enable2PC) handleLoadSuccess(msg, txnAcc)
+ if (enable2PC) handleLoadSuccess(accId, msg, txnAcc)
batchIterator.close()
case Failure(e) =>
if (enable2PC) handleLoadFailure(txnAcc)
@@ -110,11 +112,11 @@ class DorisWriter(settings: SparkSettings,

}

- private def handleLoadSuccess(msg: Option[CommitMessage], acc: CollectionAccumulator[CommitMessage]): Unit = {
- acc.add(msg.get)
+ private def handleLoadSuccess(id: String, msg: Option[CommitMessage], acc: CollectionAccumulator[(String, CommitMessage)]): Unit = {
+ acc.add((id, msg.get))
}

- private def handleLoadFailure(acc: CollectionAccumulator[CommitMessage]): Unit = {
+ private def handleLoadFailure(acc: CollectionAccumulator[(String, CommitMessage)]): Unit = {
// if task run failed, acc value will not be returned to driver,
// should abort all pre committed transactions inside the task
logger.info("load task failed, start aborting previously pre-committed transactions")
@@ -123,7 +125,7 @@ class DorisWriter(settings: SparkSettings,
return
}

- try txnHandler.abortTransactions(acc.value.asScala.toList)
+ try txnHandler.abortTransactions(acc.value.asScala.map(_._2).toList)
catch {
case e: Exception => throw e
}
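
How the owner id is chosen in DorisWriter: batch writes pass no id, so each partition tags its messages with the stage id taken from TaskContext; streaming writes pass the query's runId through write(dataFrame, Some(runId)). A hedged sketch of that selection (the helper name is made up, and TaskContext.get is only non-null inside an executor task):

```scala
// Sketch of the accId selection mirrored from doWrite above.
import org.apache.spark.TaskContext

object OwnerIdSketch {
  // id == None      -> batch write, fall back to the current stage id
  // id == Some(run) -> streaming write, use the query's runId
  def ownerId(id: Option[String]): String =
    id.getOrElse(TaskContext.get().stageId().toString)
}

// Inside rdd.foreachPartition { iter => ... }:
//   val accId = OwnerIdSketch.ownerId(id)   // e.g. "12" or a run UUID string
//   acc.add((accId, commitMessage))         // tagged for later per-job filtering
```
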