Skip to content

Commit

Permalink
fix: Execute exchange tasks as new coroutines.
Browse files Browse the repository at this point in the history
  • Loading branch information
robinsons committed Dec 12, 2024
1 parent 68496d3 commit e03ff9c
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 210 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,15 @@
package org.wfanet.panelmatch.client.deploy

import java.time.Clock
import kotlinx.coroutines.Job
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.isActive
import kotlinx.coroutines.job
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.wfanet.measurement.common.logAndSuppressExceptionSuspend
import org.wfanet.measurement.common.throttler.Throttler
import org.wfanet.panelmatch.client.common.Identity
import org.wfanet.panelmatch.client.exchangetasks.ExchangeTaskMapper
import org.wfanet.panelmatch.client.launcher.ApiClient
import org.wfanet.panelmatch.client.launcher.ExchangeStepLauncher
import org.wfanet.panelmatch.client.launcher.ExchangeStepValidatorImpl
import org.wfanet.panelmatch.client.launcher.ExchangeTaskExecutor
import org.wfanet.panelmatch.client.storage.PrivateStorageSelector
Expand All @@ -36,6 +34,7 @@ import org.wfanet.panelmatch.client.storage.StorageDetailsProvider
import org.wfanet.panelmatch.common.ExchangeDateKey
import org.wfanet.panelmatch.common.Timeout
import org.wfanet.panelmatch.common.certificates.CertificateManager
import org.wfanet.panelmatch.common.loggerFor
import org.wfanet.panelmatch.common.secrets.SecretMap
import org.wfanet.panelmatch.common.storage.StorageFactory

Expand Down Expand Up @@ -112,35 +111,66 @@ abstract class ExchangeWorkflowDaemon : Runnable {
override fun run() = runBlocking { runSuspending() }

suspend fun runSuspending() {
val exchangeStepLauncher =
ExchangeStepLauncher(apiClient = apiClient, taskLauncher = stepExecutor)
when (runMode) {
RunMode.DAEMON -> runDaemon(exchangeStepLauncher)
RunMode.CRON_JOB -> runCronJob(exchangeStepLauncher)
RunMode.DAEMON -> runDaemon()
RunMode.CRON_JOB -> runCronJob()
}
}

/** Runs [exchangeStepLauncher] in an infinite loop. */
protected open suspend fun runDaemon(exchangeStepLauncher: ExchangeStepLauncher) {
/**
* Claims exchange steps and executes them in an infinite loop. Claimed steps are launched as
* child coroutines to allow multiple steps to execute concurrently.
*/
protected open suspend fun runDaemon() = coroutineScope {
throttler.loopOnReady {
// All errors thrown inside the loop should be suppressed such that the daemon doesn't crash.
logAndSuppressExceptionSuspend { exchangeStepLauncher.findAndRunExchangeStep() }
val step =
try {
apiClient.claimExchangeStep()
} catch (e: Exception) {
logger.severe("Failed to claim exchange step: $e")
null
}

if (step != null) {
launch {
try {
stepExecutor.execute(step)
} catch (e: Exception) {
logger.severe("Failed to execute exchange step: $e")
}
}
}
}
}

/**
* Runs [exchangeStepLauncher] in a loop until there are no remaining tasks and all launched tasks
* have completed.
* Claims exchange steps and executes them until all available steps are exhausted, then returns.
* Claimed steps are launched as child coroutines to allow multiple steps to execute concurrently.
*/
protected open suspend fun runCronJob(exchangeStepLauncher: ExchangeStepLauncher) {
val activeJobs = mutableListOf<Job>()
protected open suspend fun runCronJob() = coroutineScope {
do {
activeJobs.removeIf { !it.isActive }
val job = logAndSuppressExceptionSuspend { exchangeStepLauncher.findAndRunExchangeStep() }
if (job != null) {
activeJobs += job
throttler.onReady {
val step =
try {
apiClient.claimExchangeStep()
} catch (e: Exception) {
logger.severe("Failed to claim exchange step: $e")
null
}

if (step != null) {
launch {
try {
stepExecutor.execute(step)
} catch (e: Exception) {
logger.severe("Failed to execute exchange step: $e")
}
}
}
}
} while (currentCoroutineContext().isActive && activeJobs.isNotEmpty())
} while (coroutineContext.isActive && coroutineContext.job.children.any { it.isActive })

logger.info("All available steps executed; shutting down.")
}

enum class RunMode {
Expand All @@ -153,4 +183,8 @@ abstract class ExchangeWorkflowDaemon : Runnable {
*/
CRON_JOB,
}

companion object {
private val logger by loggerFor()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@

package org.wfanet.panelmatch.client.launcher

import kotlinx.coroutines.Job

/** Executes [ApiClient.ClaimedExchangeStep]s. */
interface ExchangeStepExecutor {
/** Executes [exchangeStep] in a new coroutine and returns the running [Job]. */
suspend fun execute(exchangeStep: ApiClient.ClaimedExchangeStep): Job
/** Executes [exchangeStep]. */
suspend fun execute(exchangeStep: ApiClient.ClaimedExchangeStep)
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,9 @@ package org.wfanet.panelmatch.client.launcher
import com.google.protobuf.ByteString
import java.util.logging.Level
import java.util.logging.Logger
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import org.wfanet.measurement.storage.StorageClient
import org.wfanet.measurement.storage.StorageClient.Blob
import org.wfanet.panelmatch.client.common.ExchangeContext
Expand Down Expand Up @@ -54,12 +49,11 @@ class ExchangeTaskExecutor(
private val privateStorageSelector: PrivateStorageSelector,
private val exchangeTaskMapper: ExchangeTaskMapper,
private val validator: ExchangeStepValidator,
private val dispatcher: CoroutineDispatcher = Dispatchers.Default,
) : ExchangeStepExecutor {

override suspend fun execute(exchangeStep: ApiClient.ClaimedExchangeStep): Job = coroutineScope {
override suspend fun execute(exchangeStep: ApiClient.ClaimedExchangeStep) {
val attemptKey = exchangeStep.attemptKey
launch(dispatcher + CoroutineName(attemptKey.toString()) + TaskLog(attemptKey.toString())) {
withContext(CoroutineName(attemptKey.toString()) + TaskLog(attemptKey.toString())) {
try {
val validatedStep = validator.validate(exchangeStep)
val context =
Expand All @@ -79,7 +73,6 @@ class ExchangeTaskExecutor(
else -> ExchangeStepAttempt.State.FAILED
}
markAsFinished(attemptKey, attemptState)
cancel("Task failed and reported back to Kingdom. Cancelling task scope.", e)
}
}
}
Expand Down
22 changes: 0 additions & 22 deletions src/test/kotlin/org/wfanet/panelmatch/client/launcher/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,6 @@ kt_jvm_test(
],
)

kt_jvm_test(
name = "ExchangeStepLauncherTest",
timeout = "short",
srcs = ["ExchangeStepLauncherTest.kt"],
test_class = "org.wfanet.panelmatch.client.launcher.ExchangeStepLauncherTest",
deps = [
"//src/main/kotlin/org/wfanet/panelmatch/client/common",
"//src/main/kotlin/org/wfanet/panelmatch/client/launcher",
"//src/main/kotlin/org/wfanet/panelmatch/common",
"//src/main/kotlin/org/wfanet/panelmatch/common/testing",
"//src/main/proto/wfa/panelmatch/client/internal:exchange_workflow_kt_jvm_proto",
"@wfa_common_jvm//imports/java/com/google/common/truth",
"@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto",
"@wfa_common_jvm//imports/java/org/junit",
"@wfa_common_jvm//imports/java/org/mockito",
"@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core",
"@wfa_common_jvm//imports/kotlin/org/mockito/kotlin",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common",
"@wfa_measurement_system//src/main/kotlin/org/wfanet/measurement/common/api:resource_key",
],
)

kt_jvm_test(
name = "ExchangeStepValidatorImplTest",
timeout = "short",
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import com.google.protobuf.ByteString
import com.google.protobuf.kotlin.toByteStringUtf8
import java.time.LocalDate
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.job
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
Expand Down Expand Up @@ -109,8 +108,7 @@ class ExchangeTaskExecutorTest {
prepareBlob("some-blob")
whenever(validator.validate(any())).thenReturn(VALIDATED_EXCHANGE_STEP)

val job = exchangeTaskExecutor.execute(EXCHANGE_STEP)
job.join()
exchangeTaskExecutor.execute(EXCHANGE_STEP)

assertThat(testPrivateStorageSelector.storageClient.getBlob("c")?.toStringUtf8())
.isEqualTo("Out:commutative-deterministic-encrypt-some-blob")
Expand All @@ -121,8 +119,7 @@ class ExchangeTaskExecutorTest {
timeout.expired = true
whenever(validator.validate(any())).thenReturn(VALIDATED_EXCHANGE_STEP)

val job = exchangeTaskExecutor.execute(EXCHANGE_STEP)
job.join()
exchangeTaskExecutor.execute(EXCHANGE_STEP)

assertThat(testPrivateStorageSelector.storageClient.getBlob("c")).isNull()
}
Expand All @@ -135,8 +132,7 @@ class ExchangeTaskExecutorTest {
val exchangeTaskExecutor =
createExchangeTaskExecutor(FakeExchangeTaskMapper(::TransientThrowingExchangeTask))

val job = exchangeTaskExecutor.execute(EXCHANGE_STEP)
job.join()
exchangeTaskExecutor.execute(EXCHANGE_STEP)

verify(apiClient).finishExchangeStepAttempt(eq(ATTEMPT_KEY), eq(State.FAILED), any())
}
Expand All @@ -149,8 +145,7 @@ class ExchangeTaskExecutorTest {
val exchangeTaskExecutor =
createExchangeTaskExecutor(FakeExchangeTaskMapper(::PermanentThrowingExchangeTask))

val job = exchangeTaskExecutor.execute(EXCHANGE_STEP)
job.join()
exchangeTaskExecutor.execute(EXCHANGE_STEP)

verify(apiClient).finishExchangeStepAttempt(eq(ATTEMPT_KEY), eq(State.FAILED_STEP), any())
}
Expand Down

0 comments on commit e03ff9c

Please sign in to comment.